diff --git a/llama-cli/src/cli_args.rs b/llama-cli/src/cli_args.rs index 5dd663db..ff9556c5 100644 --- a/llama-cli/src/cli_args.rs +++ b/llama-cli/src/cli_args.rs @@ -260,60 +260,59 @@ pub struct ModelLoad { pub num_ctx_tokens: usize, } impl ModelLoad { - pub fn load(&self) -> (llama_rs::Model, llama_rs::Vocabulary) { - let (model, vocabulary) = - llama_rs::Model::load(&self.model_path, self.num_ctx_tokens, |progress| { - use llama_rs::LoadProgress; - match progress { - LoadProgress::HyperparametersLoaded(hparams) => { - log::debug!("Loaded hyperparameters {hparams:#?}") - } - LoadProgress::ContextSize { bytes } => log::info!( - "ggml ctx size = {:.2} MB\n", - bytes as f64 / (1024.0 * 1024.0) - ), - LoadProgress::PartLoading { - file, + pub fn load(&self) -> llama_rs::Model { + let model = llama_rs::Model::load(&self.model_path, self.num_ctx_tokens, |progress| { + use llama_rs::LoadProgress; + match progress { + LoadProgress::HyperparametersLoaded(hparams) => { + log::debug!("Loaded hyperparameters {hparams:#?}") + } + LoadProgress::ContextSize { bytes } => log::info!( + "ggml ctx size = {:.2} MB\n", + bytes as f64 / (1024.0 * 1024.0) + ), + LoadProgress::PartLoading { + file, + current_part, + total_parts, + } => { + let current_part = current_part + 1; + log::info!( + "Loading model part {}/{} from '{}'\n", current_part, total_parts, - } => { - let current_part = current_part + 1; - log::info!( - "Loading model part {}/{} from '{}'\n", - current_part, - total_parts, - file.to_string_lossy(), - ) - } - LoadProgress::PartTensorLoaded { - current_tensor, - tensor_count, - .. 
- } => { - let current_tensor = current_tensor + 1; - if current_tensor % 8 == 0 { - log::info!("Loaded tensor {current_tensor}/{tensor_count}"); - } - } - LoadProgress::PartLoaded { - file, - byte_size, - tensor_count, - } => { - log::info!("Loading of '{}' complete", file.to_string_lossy()); - log::info!( - "Model size = {:.2} MB / num tensors = {}", - byte_size as f64 / 1024.0 / 1024.0, - tensor_count - ); + file.to_string_lossy(), + ) + } + LoadProgress::PartTensorLoaded { + current_tensor, + tensor_count, + .. + } => { + let current_tensor = current_tensor + 1; + if current_tensor % 8 == 0 { + log::info!("Loaded tensor {current_tensor}/{tensor_count}"); } } - }) - .expect("Could not load model"); + LoadProgress::PartLoaded { + file, + byte_size, + tensor_count, + } => { + log::info!("Loading of '{}' complete", file.to_string_lossy()); + log::info!( + "Model size = {:.2} MB / num tensors = {}", + byte_size as f64 / 1024.0 / 1024.0, + tensor_count + ); + } + } + }) + .expect("Could not load model"); log::info!("Model fully loaded!"); - (model, vocabulary) + model } } diff --git a/llama-cli/src/main.rs b/llama-cli/src/main.rs index 63c8b0fe..f0f072b8 100644 --- a/llama-cli/src/main.rs +++ b/llama-cli/src/main.rs @@ -27,7 +27,7 @@ fn main() { fn infer(args: &cli_args::Infer) { let prompt = load_prompt_file_with_prompt(&args.prompt_file, args.prompt.as_deref()); let inference_session_params = args.generate.inference_session_parameters(); - let (model, vocabulary) = args.model_load.load(); + let model = args.model_load.load(); let (mut session, session_loaded) = snapshot::read_or_create_session( &model, args.persist_session.as_deref(), @@ -39,7 +39,6 @@ fn infer(args: &cli_args::Infer) { let mut rng = args.generate.rng(); let res = session.inference_with_prompt::( &model, - &vocabulary, &inference_params, &prompt, args.generate.num_predict, @@ -73,8 +72,8 @@ fn infer(args: &cli_args::Infer) { fn dump_tokens(args: &cli_args::DumpTokens) { let prompt = 
load_prompt_file_with_prompt(&args.prompt_file, args.prompt.as_deref()); - let (_, vocabulary) = args.model_load.load(); - let toks = match vocabulary.tokenize(&prompt, false) { + let model = args.model_load.load(); + let toks = match model.vocabulary().tokenize(&prompt, false) { Ok(toks) => toks, Err(e) => { log::error!("Could not tokenize prompt: {e}"); @@ -106,7 +105,7 @@ fn interactive( ) { let prompt_file = args.prompt_file.contents(); let inference_session_params = args.generate.inference_session_parameters(); - let (model, vocabulary) = args.model_load.load(); + let model = args.model_load.load(); let (mut session, session_loaded) = snapshot::read_or_create_session( &model, None, @@ -135,7 +134,6 @@ fn interactive( let mut sp = spinners::Spinner::new(spinners::Spinners::Dots2, "".to_string()); if let Err(InferenceError::ContextFull) = session.feed_prompt::( &model, - &vocabulary, &inference_params, &prompt, |_| Ok(()), @@ -146,7 +144,6 @@ fn interactive( let res = session.inference_with_prompt::( &model, - &vocabulary, &inference_params, "", args.generate.num_predict, diff --git a/llama-cli/src/snapshot.rs b/llama-cli/src/snapshot.rs index 5107aebc..3601de76 100644 --- a/llama-cli/src/snapshot.rs +++ b/llama-cli/src/snapshot.rs @@ -26,7 +26,7 @@ pub fn read_or_create_session( let snapshot = unwrap_or_exit(bincode::deserialize_from(decoder), || { format!("Could not deserialize inference session from {path:?}") }); - let session = unwrap_or_exit(model.session_from_snapshot(snapshot), || { + let session = unwrap_or_exit(InferenceSession::from_snapshot(snapshot, model), || { format!("Could not convert snapshot from {path:?} to session") }); log::info!("Loaded inference session from {path:?}"); diff --git a/llama-rs/src/inference_session.rs b/llama-rs/src/inference_session.rs new file mode 100644 index 00000000..3af27812 --- /dev/null +++ b/llama-rs/src/inference_session.rs @@ -0,0 +1,574 @@ +use std::fmt::Display; + +use partial_sort::PartialSort; +use 
rand::{distributions::WeightedIndex, prelude::Distribution}; +use thiserror::Error; + +use crate::{ + util::mulf, EvaluateOutputRequest, InferenceError, InferenceParameters, Model, TokenId, + TokenUtf8Buffer, EOT_TOKEN_ID, +}; + +// The size of a scratch buffer used for inference. This is used for temporary +// storage of intermediate results during inference. +// +// The specific value was copied from `llama.cpp`. +const SCRATCH_SIZE: usize = 512 * 1024 * 1024; + +/// An inference session represents the state of the text generation. This holds +/// the full context window, as well as several additional parameters used +/// during sampling. +pub struct InferenceSession { + // Must be kept alive for the model + pub(crate) _session_ctx: ggml::Context, + + // Original size of the memory used to create this context. + pub(crate) memory_size: usize, + + // Parameters for the session. + pub(crate) params: InferenceSessionParameters, + + pub(crate) memory_k: ggml::Tensor, + pub(crate) memory_v: ggml::Tensor, + + /// How many tokens have been fed into the model's working memory so far. + pub(crate) n_past: usize, + + /// How much memory is required per token for the temporary context used + /// during inference. + pub(crate) mem_per_token: usize, + + /// All tokens generated by this inference session + pub(crate) tokens: Vec, + + /// The logits that were last predicted by the network. Zeroed out otherwise. + pub(crate) last_logits: Vec, + + /// Scratch buffers used during inference. + /// + /// The number of scratch buffers was copied from `llama.cpp`. + /// There is no specific reason for this number, but one is insufficient. + pub(crate) scratch: [ggml::Buffer; 2], +} +impl InferenceSession { + /// Feed a prompt to the model for this session. 
+ pub fn feed_prompt( + &mut self, + model: &Model, + params: &InferenceParameters, + prompt: &str, + mut callback: impl FnMut(&[u8]) -> Result<(), E>, + ) -> Result<(), InferenceError> { + let beginning_of_sentence = self.n_past == 0; + + let vocab = model.vocabulary(); + let prompt_tokens: Vec = vocab + .tokenize(prompt, beginning_of_sentence)? + .iter() + .map(|(_, tok)| *tok) + .collect(); + + if self.n_past + prompt_tokens.len() >= model.hparams.n_ctx { + return Err(InferenceError::ContextFull); + } + + for batch in prompt_tokens.chunks(params.n_batch) { + model.evaluate(self, params, batch, &mut EvaluateOutputRequest::default()); + for &tk in batch { + // NOTE: No string ever tokenizes to the end of sentence. So we + // can just return the id here. + if let Err(e) = callback(vocab.token(tk as usize)) { + return Err(InferenceError::UserCallback(Box::new(e))); + } + + // Update the tokens for this session + self.tokens.push(tk); + } + } + + Ok(()) + } + + /// Infer the next token for this session. + pub fn infer_next_token<'v>( + &mut self, + model: &'v Model, + params: &InferenceParameters, + rng: &mut impl rand::Rng, + ) -> Result<&'v [u8], InferenceError> { + if self.n_past + 1 >= model.hparams.n_ctx { + return Err(InferenceError::ContextFull); + } + + // First, sample the next token, using the stored last_logits; + let next_token = self.sample_top_p_top_k(params, rng); + + // Update the tokens for this session + self.tokens.push(next_token); + + // Then, evaluate the network again to compute the new last_logits + model.evaluate( + self, + params, + &[next_token], + &mut EvaluateOutputRequest::default(), + ); + + // Return the next token + if next_token as TokenId == EOT_TOKEN_ID { + Err(InferenceError::EndOfText) + } else { + Ok(model.vocabulary().token(next_token as usize)) + } + } + + // todo: see if we can reduce the arguments here somehow - consolidate model and vocab maybe? 
+ /// Helper function to run inference with this session and the given model and vocabulary. + /// The `callback` is called with each new token until inference is complete. + /// + /// If `params.play_back_previous_tokens` is specified, this will "play back" all existing tokens in the session. + pub fn inference_with_prompt( + &mut self, + model: &Model, + params: &InferenceParameters, + prompt: &str, + maximum_token_count: Option, + rng: &mut impl rand::Rng, + mut callback: impl FnMut(&str) -> Result<(), E>, + ) -> Result { + let maximum_token_count = maximum_token_count.unwrap_or(usize::MAX); + if params.play_back_previous_tokens { + // "Play back" the existing tokens, so that loading from an inference snapshot works + // as expected. + let mut token_utf8_buf = TokenUtf8Buffer::new(); + for token_id in &self.tokens { + // Buffer the token until it's valid UTF-8, then call the callback. + if let Some(tokens) = + token_utf8_buf.push(model.vocabulary().token(*token_id as usize)) + { + if let Err(e) = callback(&tokens) { + return Err(InferenceError::UserCallback(Box::new(e))); + } + } + } + } + + let mut stats = InferenceStats::default(); + + let start_at = std::time::SystemTime::now(); + + // Feed the initial prompt through the transformer, to update its + // context window with new data. + self.feed_prompt( + model, + params, + prompt, + TokenUtf8Buffer::adapt_callback(&mut callback), + )?; + stats.feed_prompt_duration = start_at.elapsed().unwrap(); + stats.prompt_tokens = self.n_past; + + // After the prompt is consumed, sample tokens by repeatedly calling + // `infer_next_token`. We generate tokens until the model returns an + // EndOfText token, or we run out of space in the context window, + // or we reach the specified limit. 
+ let mut tokens_processed = 0; + let mut token_utf8_buf = TokenUtf8Buffer::new(); + while tokens_processed < maximum_token_count { + let token = match self.infer_next_token(model, params, rng) { + Ok(token) => token, + Err(InferenceError::EndOfText) => break, + Err(e) => return Err(e), + }; + + // Buffer the token until it's valid UTF-8, then call the callback. + if let Some(tokens) = token_utf8_buf.push(token) { + if let Err(e) = callback(&tokens) { + return Err(InferenceError::UserCallback(Box::new(e))); + } + } + + tokens_processed += 1; + } + stats.predict_duration = start_at.elapsed().unwrap(); + stats.predict_tokens = self.n_past; + + Ok(stats) + } + + /// Sample a token using Top-P/Top-K sampling and the last logits from this session. + pub fn sample_top_p_top_k( + &self, + params: &InferenceParameters, + rng: &mut impl rand::Rng, + ) -> TokenId { + let logits = &self.last_logits; + let n_logits = logits.len(); + let mut logits_id = Vec::<(f32, TokenId)>::with_capacity(n_logits); + + { + let scale = 1.0 / params.temperature; + for (i, &logit) in logits.iter().enumerate() { + let tid = i as TokenId; + + let val = if let Some(logit_override) = params.bias_tokens.get(tid) { + logit_override + } else if self.repetition_penalty_tokens().contains(&(i as TokenId)) { + // repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858) + // credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main + + // if score < 0 then repetition penalty has to multiplied to reduce the previous token probability + if logits[i] < 0.0 { + logit * scale * params.repeat_penalty + } else { + logit * scale / params.repeat_penalty + } + } else { + logit * scale + }; + logits_id.push((val, tid)); + } + } + + // find the top K tokens + { + logits_id.partial_sort(params.top_k, |a, b| { + // Sort descending + b.0.total_cmp(&a.0) + }); + logits_id.truncate(params.top_k); + } + + let maxl = logits_id + .iter() + .map(|x| x.0) + .max_by(f32::total_cmp) + 
.unwrap(); + + // compute probs for the top K tokens + let mut probs: Vec = logits_id + .iter() + .copied() + .map(|(k, _)| (k - maxl).exp()) + .collect(); + let sum: f32 = probs.iter().copied().sum(); + + // Normalize the probs + for p in probs.iter_mut() { + *p /= sum; + } + + // Top p sampling + if params.top_p < 1.0 { + let mut cumsum = 0.0; + for i in 0..probs.len() { + cumsum += probs[i]; + if cumsum >= params.top_p { + probs.truncate(i + 1); + logits_id.truncate(i + 1); + break; + } + } + + cumsum = 1.0 / cumsum; + for p in probs.iter_mut() { + *p *= cumsum; + } + } + + let dist = WeightedIndex::new(&probs).expect("WeightedIndex error"); + let idx = dist.sample(rng); + + logits_id[idx].1 + } + + /// Obtains a serializable snapshot of the current inference status. This + /// can be used to cache the state of the model and store them into a file. + /// + /// # Safety + /// + /// This function provides raw access to the underlying memory owned by the + /// ggml context. While the provided `InferenceSnapshotRef` object is alive, + /// no other methods for this model object should be called. + pub unsafe fn get_snapshot(&mut self) -> InferenceSnapshotRef<'_> { + let memory_k = unsafe { + std::slice::from_raw_parts(self.memory_k.data() as *mut u8, self.memory_k.nbytes()) + }; + let memory_v = unsafe { + std::slice::from_raw_parts(self.memory_v.data() as *mut u8, self.memory_v.nbytes()) + }; + + InferenceSnapshotRef { + npast: self.n_past, + session_params: self.params, + tokens: self.tokens.clone(), + logits: self.last_logits.clone(), + memory_k, + memory_v, + } + } + + /// Creates an [InferenceSession] from a snapshot. 
+ pub fn from_snapshot( + snapshot: InferenceSnapshot, + model: &Model, + ) -> Result { + let mut session = model.start_session(snapshot.session_params); + + if session.memory_k.nbytes() != snapshot.memory_k.len() + || session.memory_v.nbytes() != snapshot.memory_v.len() + { + return Err(SnapshotError::MemorySizeMismatch { + self_size: session.memory_k.nbytes() + session.memory_v.nbytes(), + input_size: snapshot.memory_k.len() + snapshot.memory_v.len(), + }); + } + + // SAFETY: We have exclusive access to Session, which means no one else + // should be touching the context's memory. We can write to it because + // we already checked the size. + unsafe { + session.memory_k.write_data(&snapshot.memory_k); + session.memory_v.write_data(&snapshot.memory_v); + } + + session.n_past = snapshot.npast; + session.tokens = snapshot.tokens; + session.last_logits = snapshot.last_logits; + + Ok(session) + } +} +impl InferenceSession { + pub(crate) fn new( + params: InferenceSessionParameters, + n_ctx: usize, + n_layer: usize, + n_embd: usize, + n_vocab: usize, + ) -> InferenceSession { + let ctx_size = { + let mut ctx_size = 0; + ctx_size += mulf!( + n_ctx, + n_layer, + n_embd, + ggml::type_sizef(params.memory_k_type.into()) + ); // memory_k + ctx_size += mulf!( + n_ctx, + n_layer, + n_embd, + ggml::type_sizef(params.memory_v_type.into()) + ); // memory_v + ctx_size += (5 + 10 * n_layer) * 256; // object overhead + ctx_size + }; + + let session_ctx = ggml::Context::init(ctx_size); + + // Initialize key + value memory tensors + let n_mem = n_layer * n_ctx; + let n_elements = n_embd * n_mem; + let memory_k = session_ctx.new_tensor_1d(params.memory_k_type.into(), n_elements); + let memory_v = session_ctx.new_tensor_1d(params.memory_v_type.into(), n_elements); + + InferenceSession { + _session_ctx: session_ctx, + memory_size: ctx_size, + params, + memory_k, + memory_v, + n_past: 0, + mem_per_token: 0, + tokens: vec![], + last_logits: vec![0.0; n_vocab], + scratch: scratch_buffers(), 
+ } + } +} +impl InferenceSession { + fn repetition_penalty_tokens(&self) -> &[TokenId] { + &self.tokens[self + .tokens + .len() + .saturating_sub(self.params.repetition_penalty_last_n)..] + } +} +impl Clone for InferenceSession { + fn clone(&self) -> Self { + let context = ggml::Context::init(self.memory_size); + let memory_k = context.new_tensor_1d(self.memory_k.get_type(), self.memory_k.nelements()); + let memory_v = context.new_tensor_1d(self.memory_v.get_type(), self.memory_v.nelements()); + + Self { + _session_ctx: context, + memory_size: self.memory_size, + params: self.params, + memory_k, + memory_v, + n_past: self.n_past, + mem_per_token: self.mem_per_token, + tokens: self.tokens.clone(), + last_logits: self.last_logits.clone(), + scratch: scratch_buffers(), + } + } +} + +#[derive(Error, Debug)] +/// Errors encountered during the snapshot process. +pub enum SnapshotError { + /// Arbitrary I/O error. + #[error("I/O error while reading or writing snapshot")] + IO(#[from] std::io::Error), + /// Mismatch between the snapshotted memory and the in-memory memory. + #[error("could not read snapshot due to size mismatch (self={self_size}, input={input_size})")] + MemorySizeMismatch { + /// The size of the session memory in memory. + self_size: usize, + /// The size of the session memory in snapshot. + input_size: usize, + }, +} + +#[derive(serde::Serialize, Clone, PartialEq)] +/// A serializable snapshot of the inference process. +/// Can be created by calling [InferenceSession::get_snapshot]. +/// +/// If serializing, ensure that your serializer is binary-efficient. +/// This type contains a large array of bytes; traditional textual serializers +/// are likely to serialize this as an array of numbers at extreme cost. +// Keep in sync with [InferenceSession] and [InferenceSnapshot]. +pub struct InferenceSnapshotRef<'a> { + /// How many tokens have been stored in the memory so far. + pub npast: usize, + /// Parameters associated with the saved inference session. 
+ pub session_params: InferenceSessionParameters, + /// All tokens generated by this inference session + pub tokens: Vec, + /// The vector of logits that was produced after the last inference + pub logits: Vec, + /// The contents of the 'key' memory tensor + #[serde(with = "serde_bytes")] + pub memory_k: &'a [u8], + /// The contents of the 'value' memory tensor + #[serde(with = "serde_bytes")] + pub memory_v: &'a [u8], +} +impl InferenceSnapshotRef<'_> { + /// Creates an owned [InferenceSnapshot] from this [InferenceSnapshotRef]. + /// + /// The [ToOwned] trait is not used due to its blanket implementation for all [Clone] types. + pub fn to_owned(&self) -> InferenceSnapshot { + InferenceSnapshot { + npast: self.npast, + session_params: self.session_params, + tokens: self.tokens.clone(), + last_logits: self.logits.clone(), + memory_k: self.memory_k.to_vec(), + memory_v: self.memory_v.to_vec(), + } + } +} + +/// A serializable snapshot of the inference process. Can be restored by calling +/// [InferenceSession::from_snapshot]. +#[derive(serde::Deserialize, Clone, PartialEq)] +// Keep in sync with [InferenceSession] and [InferenceSnapshotRef]. +pub struct InferenceSnapshot { + /// How many tokens have been stored in the memory so far. + pub npast: usize, + /// Parameters associated with the saved inference session. + pub session_params: InferenceSessionParameters, + /// All tokens generated by this inference session + pub tokens: Vec, + /// The vector of logits that was produced after the last inference + pub last_logits: Vec, + /// The contents of the 'key' memory tensor + #[serde(with = "serde_bytes")] + pub memory_k: Vec, + /// The contents of the 'value' memory tensor + #[serde(with = "serde_bytes")] + pub memory_v: Vec, +} + +#[derive(Clone, Copy, Debug, PartialEq, serde::Serialize, serde::Deserialize)] +/// Parameters for an inference session. +pub struct InferenceSessionParameters { + /// The number of tokens to consider for the repetition penalty. 
+ pub repetition_penalty_last_n: usize, + /// The type of the memory K tensor. + pub memory_k_type: ModelKVMemoryType, + /// The type of the memory V tensor. + pub memory_v_type: ModelKVMemoryType, +} +impl Default for InferenceSessionParameters { + fn default() -> Self { + Self { + repetition_penalty_last_n: 512, + memory_k_type: ModelKVMemoryType::Float32, + memory_v_type: ModelKVMemoryType::Float32, + } + } +} + +/// Statistics about the inference process. +#[derive(Debug, Clone, Copy)] +pub struct InferenceStats { + /// How long it took to feed the prompt. + pub feed_prompt_duration: std::time::Duration, + /// How many tokens the prompt was. + pub prompt_tokens: usize, + /// How long it took to predict new tokens. + pub predict_duration: std::time::Duration, + /// The number of predicted tokens. + pub predict_tokens: usize, +} +impl Default for InferenceStats { + fn default() -> Self { + Self { + feed_prompt_duration: std::time::Duration::from_secs(0), + prompt_tokens: 0, + predict_duration: std::time::Duration::from_secs(0), + predict_tokens: 0, + } + } +} +impl Display for InferenceStats { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!( + f, + "feed_prompt_duration: {}ms\nprompt_tokens: {}\npredict_duration: {}ms\npredict_tokens: {}\nper_token_duration: {:.3}ms", + self.feed_prompt_duration.as_millis(), + self.prompt_tokens, + self.predict_duration.as_millis(), + self.predict_tokens, + (self.predict_duration.as_millis() as f64) / (self.predict_tokens as f64), + ) + } +} + +/// Allowed types for the model memory K/V tensors. +#[derive(Clone, Copy, Debug, PartialEq, serde::Serialize, serde::Deserialize)] +pub enum ModelKVMemoryType { + /// 16-bit float. + Float16, + /// 32-bit float. 
+ Float32, +} +impl From for ggml::Type { + fn from(value: ModelKVMemoryType) -> Self { + match value { + ModelKVMemoryType::Float16 => ggml::Type::F16, + ModelKVMemoryType::Float32 => ggml::Type::F32, + } + } +} + +fn scratch_buffers() -> [ggml::Buffer; 2] { + [ + ggml::Buffer::new(SCRATCH_SIZE), + ggml::Buffer::new(SCRATCH_SIZE), + ] +} diff --git a/llama-rs/src/lib.rs b/llama-rs/src/lib.rs index 1a1159ec..51218cb8 100644 --- a/llama-rs/src/lib.rs +++ b/llama-rs/src/lib.rs @@ -1,256 +1,30 @@ #![deny(missing_docs)] //! LLaMA-rs is a Rust port of the llama.cpp project. This allows running inference for Facebook's LLaMA model on a CPU with good performance using full precision, f16 or 4-bit quantized versions of the model. -use core::slice; -use std::{ - collections::HashMap, - fmt::Display, - io::{BufRead, Read, Seek, SeekFrom}, - path::{Path, PathBuf}, - str::FromStr, - time, -}; - -use serde::Deserialize; use thiserror::Error; -use partial_sort::PartialSort; -use rand::{distributions::WeightedIndex, prelude::Distribution}; - -pub use ggml::Type as ElementType; - #[cfg(feature = "convert")] pub mod convert; +mod inference_session; +mod loader; +mod model; mod util; +mod vocabulary; + +pub use ggml::Type as ElementType; +pub use inference_session::{ + InferenceSession, InferenceSessionParameters, InferenceSnapshot, ModelKVMemoryType, + SnapshotError, +}; +pub use loader::{LoadError, LoadProgress}; +pub use model::{Hyperparameters, Model}; +pub use util::TokenUtf8Buffer; +pub use vocabulary::{TokenBias, TokenId, Vocabulary}; /// The end of text token. pub const EOT_TOKEN_ID: TokenId = 2; // Hardcoded (for now?) -// The size of a scratch buffer used for inference. This is used for temporary -// storage of intermediate results during inference. -// -// The specific value was copied from `llama.cpp`. -const SCRATCH_SIZE: usize = 512 * 1024 * 1024; - -/// The hyperparameters of the model. 
-#[derive(Debug, Default, PartialEq, Eq, PartialOrd, Ord, Deserialize)] -pub struct Hyperparameters { - n_vocab: usize, - n_ctx: usize, - n_embd: usize, - n_mult: usize, - n_head: usize, - n_layer: usize, - n_rot: usize, - f16_: u32, -} - -struct Layer { - attention_norm: ggml::Tensor, - - wq: ggml::Tensor, - wk: ggml::Tensor, - wv: ggml::Tensor, - wo: ggml::Tensor, - - // normalization - ffn_norm: ggml::Tensor, - - // ff - w1: ggml::Tensor, - w2: ggml::Tensor, - w3: ggml::Tensor, -} - -/// The weights for the LLaMA model. All the mutable state is split into a -/// separate struct `InferenceSession`. -pub struct Model { - hparams: Hyperparameters, - - tok_embeddings: ggml::Tensor, - - norm: ggml::Tensor, - output: ggml::Tensor, - - layers: Vec, - - tensors: HashMap, - - // Must be kept alive for the model - _context: ggml::Context, -} - -/// An inference session represents the state of the text generation. This holds -/// the full context window, as long as several additional parameters used -/// during sampling. -pub struct InferenceSession { - // Must be kept alive for the model - _session_ctx: ggml::Context, - - // Original size of the memory used to create this context. - memory_size: usize, - - // Parameters for the session. - params: InferenceSessionParameters, - - memory_k: ggml::Tensor, - memory_v: ggml::Tensor, - - /// How many tokens have been fed into the model's working memory so far. - n_past: usize, - - /// How much memory is required per token for the temporary context used - /// during inference. - mem_per_token: usize, - - /// All tokens generated by this inference session - tokens: Vec, - - /// The logits that were last predicted by the network. Zeroed out otherwise. - last_logits: Vec, - - /// Scratch buffers used during inference. - /// - /// The number of scratch buffers was copied from `llama.cpp`. - /// There is no specific reason for this number, but one is insufficient. 
- scratch: [ggml::Buffer; 2], -} -impl InferenceSession { - fn repetition_penalty_tokens(&self) -> &[TokenId] { - &self.tokens[self - .tokens - .len() - .saturating_sub(self.params.repetition_penalty_last_n)..] - } -} -impl Clone for InferenceSession { - fn clone(&self) -> Self { - let context = ggml::Context::init(self.memory_size); - let memory_k = context.new_tensor_1d(self.memory_k.get_type(), self.memory_k.nelements()); - let memory_v = context.new_tensor_1d(self.memory_v.get_type(), self.memory_v.nelements()); - - Self { - _session_ctx: context, - memory_size: self.memory_size, - params: self.params, - memory_k, - memory_v, - n_past: self.n_past, - mem_per_token: self.mem_per_token, - tokens: self.tokens.clone(), - last_logits: self.last_logits.clone(), - scratch: inference_session_scratch_buffers(), - } - } -} - -fn inference_session_scratch_buffers() -> [ggml::Buffer; 2] { - [ - ggml::Buffer::new(SCRATCH_SIZE), - ggml::Buffer::new(SCRATCH_SIZE), - ] -} - -#[derive(serde::Serialize, Clone, PartialEq)] -/// A serializable snapshot of the inference process. -/// Can be created by calling [InferenceSession::get_snapshot]. -/// -/// If serializing, ensure that your serializer is binary-efficient. -/// This type contains a large array of bytes; traditional textual serializers -/// are likely to serialize this as an array of numbers at extreme cost. -// Keep in sync with [InferenceSession] and [InferenceSnapshot]. -pub struct InferenceSnapshotRef<'a> { - /// How many tokens have been stored in the memory so far. - pub npast: usize, - /// Parameters associated with the saved inference session. 
- pub session_params: InferenceSessionParameters, - /// All tokens generated by this inference session - pub tokens: Vec, - /// The vector of logits that was produced after the last inference - pub logits: Vec, - /// The contents of the 'key' memory tensor - #[serde(with = "serde_bytes")] - pub memory_k: &'a [u8], - /// The contents of the 'value' memory tensor - #[serde(with = "serde_bytes")] - pub memory_v: &'a [u8], -} -impl InferenceSnapshotRef<'_> { - /// Creates an owned [InferenceSnapshot] from this [InferenceSnapshotRef]. - /// - /// The [ToOwned] trait is not used due to its blanket implementation for all [Clone] types. - pub fn to_owned(&self) -> InferenceSnapshot { - InferenceSnapshot { - npast: self.npast, - session_params: self.session_params, - tokens: self.tokens.clone(), - last_logits: self.logits.clone(), - memory_k: self.memory_k.to_vec(), - memory_v: self.memory_v.to_vec(), - } - } -} - -/// A serializable snapshot of the inference process. Can be restored by calling -/// [Model::session_from_snapshot]. -#[derive(serde::Deserialize, Clone, PartialEq)] -// Keep in sync with [InferenceSession] and [InferenceSnapshotRef]. -pub struct InferenceSnapshot { - /// How many tokens have been stored in the memory so far. - pub npast: usize, - /// Parameters associated with the saved inference session. - pub session_params: InferenceSessionParameters, - /// All tokens generated by this inference session - pub tokens: Vec, - /// The vector of logits that was produced after the last inference - pub last_logits: Vec, - /// The contents of the 'key' memory tensor - #[serde(with = "serde_bytes")] - pub memory_k: Vec, - /// The contents of the 'value' memory tensor - #[serde(with = "serde_bytes")] - pub memory_v: Vec, -} - -/// Allowed types for the model memory K/V tensors. -#[derive(Clone, Copy, Debug, PartialEq, serde::Serialize, serde::Deserialize)] -pub enum ModelKVMemoryType { - /// 16-bit float. - Float16, - /// 32-bit float. 
- Float32, -} -impl From for ggml::Type { - fn from(value: ModelKVMemoryType) -> Self { - match value { - ModelKVMemoryType::Float16 => ggml::Type::F16, - ModelKVMemoryType::Float32 => ggml::Type::F32, - } - } -} - -#[derive(Clone, Copy, Debug, PartialEq, serde::Serialize, serde::Deserialize)] -/// Parameters for an inference session. -pub struct InferenceSessionParameters { - /// The number of tokens to consider for the repetition penalty. - pub repetition_penalty_last_n: usize, - /// The type of the memory K tensor. - pub memory_k_type: ModelKVMemoryType, - /// The type of the memory V tensor. - pub memory_v_type: ModelKVMemoryType, -} - -impl Default for InferenceSessionParameters { - fn default() -> Self { - Self { - repetition_penalty_last_n: 512, - memory_k_type: ModelKVMemoryType::Float32, - memory_v_type: ModelKVMemoryType::Float32, - } - } -} - #[derive(Clone, Debug, PartialEq)] /// The parameters that drive text generation. pub struct InferenceParameters { @@ -274,7 +48,6 @@ pub struct InferenceParameters { /// Whether or not previous tokens should be played back in [InferenceSession::inference_with_prompt]. pub play_back_previous_tokens: bool, } - impl Default for InferenceParameters { fn default() -> Self { Self { @@ -290,273 +63,6 @@ impl Default for InferenceParameters { } } -/// Statistics about the inference process. -#[derive(Debug, Clone, Copy)] -pub struct InferenceStats { - /// How long it took to feed the prompt. - pub feed_prompt_duration: std::time::Duration, - /// How many tokens the prompt was. - pub prompt_tokens: usize, - /// How long it took to predict new tokens. - pub predict_duration: std::time::Duration, - /// The number of predicted tokens. 
- pub predict_tokens: usize, -} - -impl Default for InferenceStats { - fn default() -> Self { - Self { - feed_prompt_duration: std::time::Duration::from_secs(0), - prompt_tokens: 0, - predict_duration: std::time::Duration::from_secs(0), - predict_tokens: 0, - } - } -} - -impl Display for InferenceStats { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!( - f, - "feed_prompt_duration: {}ms\nprompt_tokens: {}\npredict_duration: {}ms\npredict_tokens: {}\nper_token_duration: {:.3}ms", - self.feed_prompt_duration.as_millis(), - self.prompt_tokens, - self.predict_duration.as_millis(), - self.predict_tokens, - (self.predict_duration.as_millis() as f64) / (self.predict_tokens as f64), - ) - } -} - -type TokenId = i32; -type Token = Vec; -type TokenScore = f32; - -/// The vocabulary used by a model. -#[derive(Debug, Clone)] -pub struct Vocabulary { - /// Maps every integer (index) token id to its corresponding token - id_to_token: Vec, - - /// Maps every integer (index) token id to corresponding score - #[allow(dead_code)] - id_to_token_score: Vec, - - /// Maps a token to a token id - token_to_id: HashMap, - - /// The longest token in this vocabulary - max_token_length: usize, -} -impl Vocabulary { - fn token(&self, idx: usize) -> &[u8] { - &self.id_to_token[idx] - } -} - -#[derive(Default, Clone, Debug, PartialEq)] -/// A list of tokens to bias during the process of inferencing. -/// -/// When a biased token is encountered, the bias will be used -/// instead of the inferred logit during the sampling process. -/// -/// This can be used to disable the generation of responses -/// with specific tokens by setting their corresponding bias -/// to -1.0. -pub struct TokenBias(Vec<(TokenId, f32)>); - -impl TokenBias { - /// Create a [TokenBias] from an existing `Vec`. 
- pub fn new(mut v: Vec<(TokenId, f32)>) -> Self { - v.sort_by_cached_key(|(tid, _)| *tid); - v.dedup_by_key(|(tid, _)| *tid); - Self(v) - } - - /// Retrieves the bias for a given token, if available. - pub fn get(&self, tid: TokenId) -> Option { - self.0 - .binary_search_by_key(&tid, |(tid, _)| *tid) - .map(|idx| self.0[idx].1) - .ok() - } -} - -impl FromStr for TokenBias { - type Err = String; - - /// A comma separated list of token biases. The list should be in the format - /// "TID=BIAS,TID=BIAS" where TID is an integer token ID and BIAS is a - /// floating point number. - /// For example, "1=-1.0,2=-1.0" sets the bias for token IDs 1 - /// (start of document) and 2 (end of document) to -1.0 which effectively - /// disables the model from generating responses containing those token IDs. - fn from_str(s: &str) -> Result { - let x = s - .split(',') - .map(|kv| { - let (k, v) = kv - .trim() - .split_once('=') - .ok_or_else(|| "Missing '=' in bias item".to_owned())?; - let tid: TokenId = k - .trim() - .parse() - .map_err(|e: std::num::ParseIntError| e.to_string())?; - let bias: f32 = v - .trim() - .parse() - .map_err(|e: std::num::ParseFloatError| e.to_string())?; - Result::<_, String>::Ok((tid, bias)) - }) - .collect::>()?; - Ok(TokenBias::new(x)) - } -} - -impl std::fmt::Display for TokenBias { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{:?}", self.0) - } -} - -/// Each variant represents a step within the process of loading the model. -/// These can be used to report progress to the user. -#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Debug)] -pub enum LoadProgress<'a> { - /// The hyperparameters have been loaded from the model. - HyperparametersLoaded(&'a Hyperparameters), - /// The context has been created. - ContextSize { - /// The size of the context. - bytes: usize, - }, - /// A part of the model is being loaded. - PartLoading { - /// The path to the model part. - file: &'a Path, - /// The current part (0-indexed). 
- current_part: usize, - /// The number of total parts. - total_parts: usize, - }, - /// A tensor from the current part has been loaded. - PartTensorLoaded { - /// The path to the model part. - file: &'a Path, - /// The current tensor (0-indexed). - current_tensor: usize, - /// The number of total tensors. - tensor_count: usize, - }, - /// A model part has finished fully loading. - PartLoaded { - /// The path to the model part. - file: &'a Path, - /// The number of bytes in the part. - byte_size: usize, - /// The number of tensors in the part. - tensor_count: usize, - }, -} - -#[derive(Error, Debug)] -/// Errors encountered during the loading process. -pub enum LoadError { - #[error("could not open file {path:?}")] - /// A file failed to open. - OpenFileFailed { - /// The original error. - source: std::io::Error, - /// The path that failed. - path: PathBuf, - }, - #[error("no parent path for {path:?}")] - /// There is no parent path for a given path. - NoParentPath { - /// The path without a parent. - path: PathBuf, - }, - #[error("unable to read exactly {bytes} bytes")] - /// Reading exactly `bytes` from a file failed. - ReadExactFailed { - /// The original error. - source: std::io::Error, - /// The number of bytes that were attempted to be read. - bytes: usize, - }, - #[error("non-specific I/O error")] - /// A non-specific IO error. - IO(#[from] std::io::Error), - #[error("could not convert bytes to a UTF-8 string")] - /// One of the strings encountered was not valid UTF-8. - InvalidUtf8(#[from] std::string::FromUtf8Error), - #[error("invalid integer conversion")] - /// One of the integers encountered could not be converted to a more appropriate type. - InvalidIntegerConversion(#[from] std::num::TryFromIntError), - #[error("invalid magic number for {path:?}")] - /// An invalid magic number was encountered during the loading process. - InvalidMagic { - /// The path that failed. 
- path: PathBuf, - }, - #[error("invalid file format version {value}")] - /// The version of the format is not supported by this version of `llama-rs`. - InvalidFormatVersion { - /// The version that was encountered. - value: u32, - }, - #[error("invalid value {ftype} for `f16` in hyperparameters")] - /// The `f16` hyperparameter had an invalid value. - HyperparametersF16Invalid { - /// The format type that was encountered. - ftype: u32, - }, - #[error("unknown tensor `{tensor_name}` in {path:?}")] - /// The tensor `tensor_name` was encountered during the loading of `path`, but was not seen during - /// the model prelude. - UnknownTensor { - /// The name of the tensor. - tensor_name: String, - /// The path that failed. - path: PathBuf, - }, - #[error("the tensor `{tensor_name}` has the wrong size in {path:?}")] - /// The tensor `tensor_name` did not match its expected size. - TensorWrongSize { - /// The name of the tensor. - tensor_name: String, - /// The path that failed. - path: PathBuf, - }, - /// The tensor `tensor_name` did not have the expected format type. - #[error("invalid ftype {ftype} for tensor `{tensor_name}` in {path:?}")] - InvalidFtype { - /// The name of the tensor. - tensor_name: String, - /// The format type that was encountered. - ftype: u32, - /// The path that failed. - path: PathBuf, - }, -} - -#[derive(Error, Debug)] -/// Errors encountered during the snapshot process. -pub enum SnapshotError { - /// Arbitrary I/O error. - #[error("I/O error while reading or writing snapshot")] - IO(#[from] std::io::Error), - /// Mismatch between the snapshotted memory and the in-memory memory. - #[error("could not read snapshot due to size mismatch (self={self_size}, input={input_size})")] - MemorySizeMismatch { - /// The size of the session memory in memory. - self_size: usize, - /// The size of the session memory in snapshot. - input_size: usize, - }, -} - #[derive(Error, Debug)] /// Errors encountered during the inferencep rocess. 
pub enum InferenceError { @@ -586,1284 +92,3 @@ pub struct EvaluateOutputRequest { /// Output shape is `n_batch * n_embd`. pub embeddings: Option>, } - -/// NOTE: The original code relies in promotion rules and automatic cast between -/// int to float. What we do instead is use this macro to convert every term of -/// the multiplication to f64, which should have enough precision bits to hold -/// the final value, then cast to usize. I have observed a discrepancy between -/// the ctx_size found using this code, and the one in llama.cpp. The number for -/// rust ends up being slightly lower, but no "out of memory" errors are -/// reported by ggml. -macro_rules! mulf { - ($term:expr, $($terms:expr),*) => { - usize::try_from((($term as f64) $(* ($terms as f64))*) as u64).unwrap() - }; -} - -impl Model { - /// Load the model from `path` with `n_context_tokens` context tokens. - /// - /// The status of the loading process will be reported through `load_progress_callback`. - pub fn load( - path: impl AsRef, - n_context_tokens: usize, - mut load_progress_callback: impl FnMut(LoadProgress), - ) -> Result<(Model, Vocabulary), LoadError> { - use std::fs::File; - use std::io::BufReader; - - let main_path = path.as_ref(); - - let mut reader = - BufReader::new( - File::open(main_path).map_err(|e| LoadError::OpenFileFailed { - source: e, - path: main_path.to_owned(), - })?, - ); - - fn read_bytes(reader: &mut impl BufRead) -> Result<[u8; N], LoadError> { - let mut bytes = [0u8; N]; - reader - .read_exact(&mut bytes) - .map_err(|e| LoadError::ReadExactFailed { - source: e, - bytes: N, - })?; - Ok(bytes) - } - - fn read_bytes_with_len( - reader: &mut impl BufRead, - len: usize, - ) -> Result, LoadError> { - let mut bytes = vec![0u8; len]; - reader - .read_exact(&mut bytes) - .map_err(|e| LoadError::ReadExactFailed { - source: e, - bytes: len, - })?; - Ok(bytes) - } - - fn read_i32(reader: &mut impl BufRead) -> Result { - Ok(i32::from_le_bytes(read_bytes::<4>(reader)?)) - } - - fn 
read_u32(reader: &mut impl BufRead) -> Result { - Ok(u32::from_le_bytes(read_bytes::<4>(reader)?)) - } - - fn read_f32(reader: &mut impl BufRead) -> Result { - Ok(f32::from_le_bytes(read_bytes::<4>(reader)?)) - } - - /// Helper function. Reads a string from the buffer and returns it. - fn read_string(reader: &mut BufReader, len: usize) -> Result { - Ok(String::from_utf8(read_bytes_with_len(reader, len)?)?) - } - - // Verify magic - let is_legacy_model: bool = match read_u32(&mut reader)? { - ggml::FILE_MAGIC => false, - ggml::FILE_MAGIC_UNVERSIONED => true, - _ => { - return Err(LoadError::InvalidMagic { - path: main_path.to_owned(), - }) - } - }; - - // Load format version - if !is_legacy_model { - #[allow(unused_variables)] - let version: u32 = match read_u32(&mut reader)? { - ggml::FORMAT_VERSION => ggml::FORMAT_VERSION, - version => return Err(LoadError::InvalidFormatVersion { value: version }), - }; - } - - // ================= - // Load hyper params - // ================= - - // NOTE: Field order matters! Data is laid out in the file exactly - // in this order. 
- let hparams = Hyperparameters { - n_vocab: read_i32(&mut reader)?.try_into()?, - n_ctx: n_context_tokens, - n_embd: read_i32(&mut reader)?.try_into()?, - n_mult: read_i32(&mut reader)?.try_into()?, - n_head: read_i32(&mut reader)?.try_into()?, - n_layer: read_i32(&mut reader)?.try_into()?, - n_rot: read_i32(&mut reader)?.try_into()?, - f16_: read_i32(&mut reader)?.try_into()?, - }; - - let n_ff = - ((2 * (4 * hparams.n_embd) / 3 + hparams.n_mult - 1) / hparams.n_mult) * hparams.n_mult; - - load_progress_callback(LoadProgress::HyperparametersLoaded(&hparams)); - - // =============== - // Load vocabulary - // =============== - let vocab = { - let mut id_to_token = vec![]; - let mut id_to_token_score = vec![]; - let mut token_to_id = HashMap::new(); - let mut max_token_length = 0; - - for i in 0..hparams.n_vocab { - let len = read_i32(&mut reader)?; - let token = read_bytes_with_len(&mut reader, len as usize)?; - max_token_length = max_token_length.max(token.len()); - id_to_token.push(token.clone()); - token_to_id.insert(token, TokenId::try_from(i)?); - - // Token score, currently unused - if !is_legacy_model { - if let Ok(score) = read_f32(&mut reader) { - id_to_token_score.push(score); - } - } else { - // Legacy model, set empty score - id_to_token_score.push(0.); - } - } - - Vocabulary { - id_to_token, - id_to_token_score, - token_to_id, - max_token_length, - } - }; - - // for the big tensors, we have the option to store the data in 16-bit - // floats or quantized in order to save memory and also to speed up the - // computation - let wtype = match hparams.f16_ { - 0 => ggml::Type::F32, - 1 => ggml::Type::F16, - 2 => ggml::Type::Q4_0, - 3 => ggml::Type::Q4_1, - invalid => return Err(LoadError::HyperparametersF16Invalid { ftype: invalid }), - }; - - let n_embd = hparams.n_embd; - let n_layer = hparams.n_layer; - let n_vocab = hparams.n_vocab; - - let ctx_size = { - // Use 64-bit math to prevent overflow. 
- let mut ctx_size: usize = 0; - - ctx_size += mulf!(n_embd, n_vocab, ggml::type_sizef(wtype)); // tok_embeddings - - ctx_size += mulf!(n_embd, ggml::type_sizef(ggml::Type::F32)); // norm - - ctx_size += mulf!(n_embd, n_vocab, ggml::type_sizef(wtype)); // output - - ctx_size += mulf!(n_layer, n_embd, ggml::type_sizef(ggml::Type::F32)); // attention_norm - - ctx_size += mulf!(n_layer, n_embd, n_embd, ggml::type_sizef(wtype)); // wq - ctx_size += mulf!(n_layer, n_embd, n_embd, ggml::type_sizef(wtype)); // wk - ctx_size += mulf!(n_layer, n_embd, n_embd, ggml::type_sizef(wtype)); // wv - ctx_size += mulf!(n_layer, n_embd, n_embd, ggml::type_sizef(wtype)); // wo - - ctx_size += mulf!(n_layer, n_embd, ggml::type_sizef(ggml::Type::F32)); // ffn_norm - - ctx_size += mulf!(n_layer, n_ff, n_embd, ggml::type_sizef(wtype)); // w1 - ctx_size += mulf!(n_layer, n_ff, n_embd, ggml::type_sizef(wtype)); // w2 - ctx_size += mulf!(n_layer, n_ff, n_embd, ggml::type_sizef(wtype)); // w3 - - ctx_size += (5 + 10 * n_layer) * 256; // object overhead - - load_progress_callback(LoadProgress::ContextSize { bytes: ctx_size }); - - ctx_size - }; - - // Initialize the context - let context = ggml::Context::init(ctx_size); - - let model = { - let mut tensors = HashMap::new(); - - let tok_embeddings = context.new_tensor_2d(wtype, n_embd, n_vocab); - let norm = context.new_tensor_1d(ggml::Type::F32, n_embd); - let output = context.new_tensor_2d(wtype, n_embd, n_vocab); - - tensors.insert("tok_embeddings.weight".to_owned(), tok_embeddings.share()); - tensors.insert("norm.weight".to_owned(), norm.share()); - tensors.insert("output.weight".to_owned(), output.share()); - - let mut layers = Vec::new(); - for i in 0..n_layer { - let layer = Layer { - attention_norm: context.new_tensor_1d(ggml::Type::F32, n_embd), - wq: context.new_tensor_2d(wtype, n_embd, n_embd), - wk: context.new_tensor_2d(wtype, n_embd, n_embd), - wv: context.new_tensor_2d(wtype, n_embd, n_embd), - wo: context.new_tensor_2d(wtype, 
n_embd, n_embd), - ffn_norm: context.new_tensor_1d(ggml::Type::F32, n_embd), - w1: context.new_tensor_2d(wtype, n_embd, n_ff), - w2: context.new_tensor_2d(wtype, n_ff, n_embd), - w3: context.new_tensor_2d(wtype, n_embd, n_ff), - }; - - tensors.insert( - format!("layers.{i}.attention_norm.weight"), - layer.attention_norm.share(), - ); - - tensors.insert(format!("layers.{i}.attention.wq.weight"), layer.wq.share()); - tensors.insert(format!("layers.{i}.attention.wk.weight"), layer.wk.share()); - tensors.insert(format!("layers.{i}.attention.wv.weight"), layer.wv.share()); - tensors.insert(format!("layers.{i}.attention.wo.weight"), layer.wo.share()); - - tensors.insert( - format!("layers.{i}.ffn_norm.weight"), - layer.ffn_norm.share(), - ); - - tensors.insert( - format!("layers.{i}.feed_forward.w1.weight"), - layer.w1.share(), - ); - tensors.insert( - format!("layers.{i}.feed_forward.w2.weight"), - layer.w2.share(), - ); - tensors.insert( - format!("layers.{i}.feed_forward.w3.weight"), - layer.w3.share(), - ); - - layers.push(layer); - } - - Model { - hparams, - tok_embeddings, - norm, - output, - layers, - tensors, - _context: context, - } - }; - - // Close the file, but keep its offset. That way we know how to skip the - // metadata when loading the parts. 
- let file_offset = reader.stream_position()?; - drop(reader); - - let paths = util::find_all_model_files(main_path)?; - let n_parts = paths.len(); - - for (i, part_path) in paths.into_iter().enumerate() { - let part_id = i; - - load_progress_callback(LoadProgress::PartLoading { - file: &part_path, - current_part: i, - total_parts: n_parts, - }); - - let mut part_reader = BufReader::new(File::open(&part_path)?); - - // Skip metadata - part_reader.seek(SeekFrom::Start(file_offset))?; - - let mut total_size = 0; - let mut n_tensors = 0; - - // Load weights - loop { - // NOTE: Implementation from #![feature(buf_read_has_data_left)] - let is_eof = part_reader.fill_buf().map(|b| b.is_empty())?; - - if is_eof { - break; - } - - let n_dims = usize::try_from(read_i32(&mut part_reader)?)?; - let length = read_i32(&mut part_reader)?; - let ftype = read_u32(&mut part_reader)?; - - let mut nelements = 1; - let mut ne = [1i64, 1i64]; - - #[allow(clippy::needless_range_loop)] - for i in 0..n_dims { - ne[i] = read_i32(&mut part_reader)? 
as i64; - nelements *= usize::try_from(ne[i])?; - } - - let tensor_name = read_string(&mut part_reader, length as usize)?; - - let Some(tensor) = model.tensors.get(&tensor_name) - else { - return Err(LoadError::UnknownTensor { tensor_name, path: part_path }); - }; - - // split_type = 0: split by columns - // split_type = 1: split by rows - // - // split_type = 0: - // regex: - // - tok_embeddings.* - // - layers.*.attention.wo.weight - // - layers.*.feed_forward.w2.weight - - // split_type = 1: - // regex: - // - output.* - // - layers.*.attention.wq.weight - // - layers.*.attention.wk.weight - // - layers.*.attention.wv.weight - // - layers.*.feed_forward.w1.weight - // - layers.*.feed_forward.w3.weight - #[allow(clippy::if_same_then_else)] - let split_type = if tensor_name.contains("tok_embeddings") { - 0 - } else if tensor_name.contains("layers") { - if tensor_name.contains("attention.wo.weight") { - 0 - } else if tensor_name.contains("feed_forward.w2.weight") { - 0 - } else { - 1 - } - } else if tensor_name.contains("output") { - 1 - } else { - 0 - }; - - if n_dims == 1 { - if tensor.nelements() != nelements { - return Err(LoadError::TensorWrongSize { - tensor_name, - path: part_path, - }); - } - } else if tensor.nelements() / n_parts != nelements { - return Err(LoadError::TensorWrongSize { - tensor_name, - path: part_path, - }); - } - - if n_dims == 1 { - if tensor.get_ne()[0] != ne[0] || tensor.get_ne()[1] != ne[1] { - return Err(LoadError::TensorWrongSize { - tensor_name, - path: part_path, - }); - } - } else if split_type == 0 { - if tensor.get_ne()[0] / i64::try_from(n_parts)? != ne[0] - || tensor.get_ne()[1] != ne[1] - { - return Err(LoadError::TensorWrongSize { - tensor_name, - path: part_path, - }); - } - } else if tensor.get_ne()[0] != ne[0] - || tensor.get_ne()[1] / i64::try_from(n_parts)? 
!= ne[1] - { - return Err(LoadError::TensorWrongSize { - tensor_name, - path: part_path, - }); - } - - let bpe = match ftype { - 0 => ggml::type_size(ggml::Type::F32), - 1 => ggml::type_size(ggml::Type::F16), - 2 => { - assert_eq!(ne[0] % 64, 0); - ggml::type_size(ggml::Type::Q4_0) - } - 3 => { - assert_eq!(ne[0] % 64, 0); - ggml::type_size(ggml::Type::Q4_1) - } - _ => { - return Err(LoadError::InvalidFtype { - tensor_name, - ftype, - path: part_path, - }) - } - }; - - if n_dims == 1 || n_parts == 1 { - if (nelements * bpe) / ggml::blck_size(tensor.get_type()) != tensor.nbytes() { - return Err(LoadError::TensorWrongSize { - tensor_name, - path: part_path, - }); - } - - if part_id == 0 { - // SAFETY: yolo, same as original code - let slice = unsafe { - let data = tensor.data(); - std::slice::from_raw_parts_mut(data as *mut u8, tensor.nbytes()) - }; - part_reader.read_exact(slice)?; - } else { - part_reader.seek(SeekFrom::Current(tensor.nbytes() as i64))?; - } - - total_size += tensor.nbytes(); - } else { - if (nelements * bpe) / ggml::blck_size(tensor.get_type()) - != tensor.nbytes() / n_parts - { - return Err(LoadError::TensorWrongSize { - tensor_name, - path: part_path, - }); - } - - if split_type == 0 { - let np0 = ne[0]; - let row_size = (usize::try_from(tensor.get_ne()[0])? - / ggml::blck_size(tensor.get_type())) - * ggml::type_size(tensor.get_type()); - - assert_eq!(row_size, tensor.get_nb()[1]); - - for i1 in 0..ne[1] { - let offset_row = i1 as usize * row_size; - let offset = offset_row - + ((part_id * np0 as usize) / ggml::blck_size(tensor.get_type())) - * ggml::type_size(tensor.get_type()); - // SAFETY: yolo, same as original code - unsafe { - let ptr = tensor.data().add(offset); - let slice = std::slice::from_raw_parts_mut( - ptr as *mut u8, - row_size / n_parts, - ); - part_reader.read_exact(slice)?; - } - } - } else { - let np1 = ne[1]; - let row_size = (usize::try_from(tensor.get_ne()[0])? 
- / ggml::blck_size(tensor.get_type())) - * ggml::type_size(tensor.get_type()); - - for i1 in 0..ne[1] { - let offset_row = (i1 as usize + part_id * np1 as usize) * row_size; - // SAFETY: yolo, same as original code - unsafe { - let ptr = tensor.data().add(offset_row); - let slice = - std::slice::from_raw_parts_mut(ptr as *mut u8, row_size); - part_reader.read_exact(slice)?; - } - } - } - - total_size += tensor.nbytes() / n_parts; - } - - n_tensors += 1; - load_progress_callback(LoadProgress::PartTensorLoaded { - file: &part_path, - current_tensor: n_tensors.try_into()?, - tensor_count: model.tensors.len(), - }); - } - - load_progress_callback(LoadProgress::PartLoaded { - file: &part_path, - byte_size: total_size, - tensor_count: n_tensors.try_into()?, - }); - } - - Ok((model, vocab)) - } - - /// Starts a new `InferenceSession` for this model. - pub fn start_session(&self, params: InferenceSessionParameters) -> InferenceSession { - let Hyperparameters { - n_ctx, - n_embd, - n_layer, - n_vocab, - .. 
- } = self.hparams; - - let ctx_size = { - let mut ctx_size = 0; - ctx_size += mulf!( - n_ctx, - n_layer, - n_embd, - ggml::type_sizef(params.memory_k_type.into()) - ); // memory_k - ctx_size += mulf!( - n_ctx, - n_layer, - n_embd, - ggml::type_sizef(params.memory_v_type.into()) - ); // memory_v - ctx_size += (5 + 10 * n_layer) * 256; // object overhead - ctx_size - }; - - let session_ctx = ggml::Context::init(ctx_size); - - // Initialize key + value memory tensors - let n_mem = n_layer * n_ctx; - let n_elements = n_embd * n_mem; - let memory_k = session_ctx.new_tensor_1d(params.memory_k_type.into(), n_elements); - let memory_v = session_ctx.new_tensor_1d(params.memory_v_type.into(), n_elements); - - InferenceSession { - _session_ctx: session_ctx, - memory_size: ctx_size, - params, - memory_k, - memory_v, - n_past: 0, - mem_per_token: 0, - tokens: vec![], - last_logits: vec![0.0; n_vocab], - scratch: inference_session_scratch_buffers(), - } - } - - /// Evaluates the transformer. - /// - /// The provided `output_request` struct lets you specify which additional - /// data you are interested in fetching from the transformer. Setting a - /// field to a `Some` value will clear and fill the provided vector with - /// data. The provided vector will be resized to the exact output size. - pub fn evaluate( - &self, - session: &mut InferenceSession, - params: &InferenceParameters, - input_tokens: &[TokenId], - output_request: &mut EvaluateOutputRequest, - ) { - let n = input_tokens.len(); - let n_past = session.n_past; - let n_threads = params.n_threads; - - let memk_elsize = session.memory_k.element_size(); - let memv_elsize = session.memory_v.element_size(); - - let Hyperparameters { - n_vocab, - n_ctx, - n_embd, - n_mult: _, - n_head, - n_layer, - n_rot, - f16_: _, - } = self.hparams; - - // For the first run, we need to guess a maximum buffer size so we can measure - // the actual memory consumption of the temporary ggml context. 
- // - // These numbers are from `llama.cpp`, and could potentially be more efficient. - let mut buf_size = { - let buf_size_mb = if n_layer >= 80 { - 1536 - } else if n_layer >= 60 { - 1280 - } else { - 1024 - }; - buf_size_mb * 1024 * 1024 - }; - if session.mem_per_token > 0 && session.mem_per_token * n > buf_size { - // add 10% to account for ggml object overhead - buf_size = (1.1f64 * session.mem_per_token as f64 * n as f64) as usize; - }; - let ctx0 = ggml::Context::init(buf_size); - - let mut gf = ggml::ComputationGraph::new(n_threads); - - let embd = ctx0.new_tensor_1d(ggml::Type::I32, n); - unsafe { embd.write_data(bytemuck::cast_slice(input_tokens)) }; - - let mut input_layer = ctx0.op_get_rows(&self.tok_embeddings, &embd); - - for il in 0..n_layer { - let input_self_attention = input_layer.share(); - let mut current: ggml::Tensor; - - ctx0.use_scratch(Some(&mut session.scratch[0])); - - // norm - { - current = ctx0.op_rms_norm(&input_layer); - - // cur = attention_norm * cur - current = ctx0.op_mul( - &ctx0.op_repeat(&self.layers[il].attention_norm, ¤t), - ¤t, - ); - } - - // self-attention - { - // compute Q and K and RoPE them - let q_current = ctx0.op_rope( - &ctx0.op_reshape_3d( - &ctx0.op_mul_mat(&self.layers[il].wq, ¤t), - n_embd / n_head, - n_head, - n, - ), - n_past, - n_rot, - 0, - ); - let k_current = ctx0.op_rope( - &ctx0.op_reshape_3d( - &ctx0.op_mul_mat(&self.layers[il].wk, ¤t), - n_embd / n_head, - n_head, - n, - ), - n_past, - n_rot, - 0, - ); - - // store key and value to memory - { - // compute the transposed [N, n_embd] V matrix - let v_current = ctx0.op_transpose(&ctx0.op_reshape_2d( - &ctx0.op_mul_mat(&self.layers[il].wv, ¤t), - n_embd, - n, - )); - - let k = ctx0.op_view_1d( - &session.memory_k, - n * n_embd, - (memk_elsize * n_embd) * (il * n_ctx + n_past), - ); - - let v = ctx0.op_view_2d( - &session.memory_v, - n, - n_embd, - n_ctx * memv_elsize, - (il * n_ctx) * memv_elsize * n_embd + n_past * memv_elsize, - ); - - // important: 
storing RoPE-ed version of K in the KV cache! - gf.build_forward_expand(&ctx0.op_cpy(&k_current, &k)); - gf.build_forward_expand(&ctx0.op_cpy(&v_current, &v)); - } - - let q = ctx0.op_permute(&q_current, 0, 2, 1, 3); - - let k = ctx0.op_permute( - &ctx0.op_reshape_3d( - &ctx0.op_view_1d( - &session.memory_k, - (n_past + n) * n_embd, - il * n_ctx * memk_elsize * n_embd, - ), - n_embd / n_head, - n_head, - n_past + n, - ), - 0, - 2, - 1, - 3, - ); - - // K * Q - let k_q = ctx0.op_mul_mat(&k, &q); - - // KQ_scaled = KQ / sqrt(n_embd/n_head) - let k_q_scaled = ctx0.op_scale( - &k_q, - &ctx0.new_f32(1.0 / f32::sqrt(n_embd as f32 / n_head as f32)), - ); - - // KQ_masked = mask_past(KQ_scaled) - let k_q_masked = ctx0.op_diag_mask_inf(&k_q_scaled, n_past); - - // KQ = soft_max(KQ_masked) - let k_q_soft_max = ctx0.op_soft_max(&k_q_masked); - - // split cached V into n_head heads - let v = ctx0.op_view_3d( - &session.memory_v, - n_past + n, - n_embd / n_head, - n_head, - n_ctx * memv_elsize, - n_ctx * memv_elsize * n_embd / n_head, - il * n_ctx * memv_elsize * n_embd, - ); - - let k_q_v = ctx0.op_mul_mat(&v, &k_q_soft_max); - - // KQV_merged = KQV.permute(0, 2, 1, 3) - let k_q_v_merged = ctx0.op_permute(&k_q_v, 0, 2, 1, 3); - - // cur = KQV_merged.contiguous().view(n_embd, N) - current = ctx0.op_cpy( - &k_q_v_merged, - &ctx0.new_tensor_2d(ggml::Type::F32, n_embd, n), - ); - - // projection (no bias) - current = ctx0.op_mul_mat(&self.layers[il].wo, ¤t); - } - - ctx0.use_scratch(Some(&mut session.scratch[1])); - - let input_feed_forward = ctx0.op_add(¤t, &input_self_attention); - - // feed-forward network - { - // norm - { - current = ctx0.op_rms_norm(&input_feed_forward); - - // cur = ffn_norm*cur - current = ctx0.op_mul( - &ctx0.op_repeat(&self.layers[il].ffn_norm, ¤t), - ¤t, - ); - } - - let tmp = ctx0.op_mul_mat(&self.layers[il].w3, ¤t); - - current = ctx0.op_mul_mat(&self.layers[il].w1, ¤t); - - // SILU activation - current = ctx0.op_silu(¤t); - - current = 
ctx0.op_mul(¤t, &tmp); - - current = ctx0.op_mul_mat(&self.layers[il].w2, ¤t); - } - - current = ctx0.op_add(¤t, &input_feed_forward); - - // input for next layer - input_layer = current; - } - - ctx0.use_scratch(Some(&mut session.scratch[0])); - - // Used at the end to optionally extract the embeddings. - let embeddings_tensor; - - // norm - { - input_layer = ctx0.op_rms_norm(&input_layer); - - // inpL = norm*inpL - input_layer = ctx0.op_mul(&ctx0.op_repeat(&self.norm, &input_layer), &input_layer); - embeddings_tensor = input_layer.share(); - } - - // lm_head - { - input_layer = ctx0.op_mul_mat(&self.output, &input_layer); - } - - ctx0.use_scratch(None); - - // logits -> probs - // inpL = ctx0.op_soft_max(&inpL); - - // run the computation - gf.build_forward_expand(&input_layer); - ctx0.graph_compute(&mut gf); - - // return result for just the last token - // SAFETY: yolo - assert_eq!(session.last_logits.len(), n_vocab); - unsafe { - input_layer.read_data( - n_vocab * (n - 1) * std::mem::size_of::(), - bytemuck::cast_slice_mut(&mut session.last_logits), - ) - }; - - // Extract logits - if let Some(all_logits) = &mut output_request.all_logits { - all_logits.resize(n_vocab * n, 0.0); - // SAFETY: Tensor data can be read (properly aligned, initialized, - // data will not be mutated or otherwise aliased during the copy), - // and we're not reading past the end of the tensor data. - assert_eq!(input_layer.nelements(), n_vocab * n); - unsafe { - input_layer.read_data(0, bytemuck::cast_slice_mut(all_logits)); - } - } - - // Extract embeddings - if let Some(embeddings) = &mut output_request.embeddings { - embeddings.resize(n_embd * n, 0.0); - // SAFETY: Same rationale as for the "Extract logits" section applies. 
- assert_eq!(embeddings_tensor.nelements(), n_embd * n); - unsafe { - embeddings_tensor.read_data(0, bytemuck::cast_slice_mut(embeddings)); - } - } - - // Adjust the required memory per token if we didn't know that already - if session.mem_per_token == 0 { - session.mem_per_token = ctx0.used_mem() / n; - } - - // Adjust n_past to new length. - session.n_past += input_tokens.len(); - } - - /// Hydrates a previously obtained InferenceSnapshot for this model. - pub fn session_from_snapshot( - &self, - snapshot: InferenceSnapshot, - ) -> Result { - let mut session = self.start_session(snapshot.session_params); - - if session.memory_k.nbytes() != snapshot.memory_k.len() - || session.memory_v.nbytes() != snapshot.memory_v.len() - { - return Err(SnapshotError::MemorySizeMismatch { - self_size: session.memory_k.nbytes() + session.memory_v.nbytes(), - input_size: snapshot.memory_k.len() + snapshot.memory_v.len(), - }); - } - - // SAFETY: We have exclusive access to Session, which means no one else - // should be touching the context's memory. We can write to it because - // we already checked the size. - unsafe { - session.memory_k.write_data(&snapshot.memory_k); - session.memory_v.write_data(&snapshot.memory_v); - } - - session.n_past = snapshot.npast; - session.tokens = snapshot.tokens; - session.last_logits = snapshot.last_logits; - - Ok(session) - } -} - -impl InferenceSession { - /// Feed a prompt to the model for this session. - pub fn feed_prompt( - &mut self, - model: &Model, - vocab: &Vocabulary, - params: &InferenceParameters, - prompt: &str, - mut callback: impl FnMut(&[u8]) -> Result<(), E>, - ) -> Result<(), InferenceError> { - let beginning_of_sentence = self.n_past == 0; - let prompt_tokens: Vec = vocab - .tokenize(prompt, beginning_of_sentence)? 
- .iter() - .map(|(_, tok)| *tok) - .collect(); - - if self.n_past + prompt_tokens.len() >= model.hparams.n_ctx { - return Err(InferenceError::ContextFull); - } - - for batch in prompt_tokens.chunks(params.n_batch) { - model.evaluate(self, params, batch, &mut EvaluateOutputRequest::default()); - for &tk in batch { - // NOTE: No string ever tokenizes to the end of sentence. So we - // can just return the id here. - if let Err(e) = callback(vocab.token(tk as usize)) { - return Err(InferenceError::UserCallback(Box::new(e))); - } - - // Update the tokens for this session - self.tokens.push(tk); - } - } - - Ok(()) - } - - /// Infer the next token for this session. - pub fn infer_next_token<'v>( - &mut self, - model: &Model, - vocab: &'v Vocabulary, - params: &InferenceParameters, - rng: &mut impl rand::Rng, - ) -> Result<&'v [u8], InferenceError> { - if self.n_past + 1 >= model.hparams.n_ctx { - return Err(InferenceError::ContextFull); - } - - // First, sample the next token, using the stored last_logits; - let next_token = self.sample_top_p_top_k(params, rng); - - // Update the tokens for this session - self.tokens.push(next_token); - - // Then, evaluate the network again to compute the new last_logits - model.evaluate( - self, - params, - &[next_token], - &mut EvaluateOutputRequest::default(), - ); - - // Return the next token - if next_token as TokenId == EOT_TOKEN_ID { - Err(InferenceError::EndOfText) - } else { - Ok(vocab.token(next_token as usize)) - } - } - - // todo: see if we can reduce the arguments here somehow - consolidate model and vocab maybe? - /// Helper function to run inference with this session and the given model and vocabulary. - /// The `callback` is called with each new token until inference is complete. - /// - /// If `params.play_back_previous_tokens` is specified, this will "play back" all existing tokens in the session. 
- #[allow(clippy::too_many_arguments)] - pub fn inference_with_prompt( - &mut self, - model: &Model, - vocab: &Vocabulary, - params: &InferenceParameters, - prompt: &str, - maximum_token_count: Option, - rng: &mut impl rand::Rng, - mut callback: impl FnMut(&str) -> Result<(), E>, - ) -> Result { - let maximum_token_count = maximum_token_count.unwrap_or(usize::MAX); - if params.play_back_previous_tokens { - // "Play back" the existing tokens, so that loading from an inference snapshot works - // as expected. - let mut token_utf8_buf = TokenUtf8Buffer::new(); - for token_id in &self.tokens { - // Buffer the token until it's valid UTF-8, then call the callback. - if let Some(tokens) = token_utf8_buf.push(vocab.token(*token_id as usize)) { - if let Err(e) = callback(&tokens) { - return Err(InferenceError::UserCallback(Box::new(e))); - } - } - } - } - - let mut stats = InferenceStats::default(); - - let start_at = time::SystemTime::now(); - - // Feed the initial prompt through the transformer, to update its - // context window with new data. - self.feed_prompt( - model, - vocab, - params, - prompt, - TokenUtf8Buffer::adapt_callback(&mut callback), - )?; - stats.feed_prompt_duration = start_at.elapsed().unwrap(); - stats.prompt_tokens = self.n_past; - - // After the prompt is consumed, sample tokens by repeatedly calling - // `infer_next_token`. We generate tokens until the model returns an - // EndOfText token, or we run out of space in the context window, - // or we reach the specified limit. - let mut tokens_processed = 0; - let mut token_utf8_buf = TokenUtf8Buffer::new(); - while tokens_processed < maximum_token_count { - let token = match self.infer_next_token(model, vocab, params, rng) { - Ok(token) => token, - Err(InferenceError::EndOfText) => break, - Err(e) => return Err(e), - }; - - // Buffer the token until it's valid UTF-8, then call the callback. 
- if let Some(tokens) = token_utf8_buf.push(token) { - if let Err(e) = callback(&tokens) { - return Err(InferenceError::UserCallback(Box::new(e))); - } - } - - tokens_processed += 1; - } - stats.predict_duration = start_at.elapsed().unwrap(); - stats.predict_tokens = self.n_past; - - Ok(stats) - } - - /// Sample a token using Top-P/Top-K sampling and the last logits from this session. - pub fn sample_top_p_top_k( - &self, - params: &InferenceParameters, - rng: &mut impl rand::Rng, - ) -> TokenId { - let logits = &self.last_logits; - let n_logits = logits.len(); - let mut logits_id = Vec::<(f32, TokenId)>::with_capacity(n_logits); - - { - let scale = 1.0 / params.temperature; - for (i, &logit) in logits.iter().enumerate() { - let tid = i as TokenId; - - let val = if let Some(logit_override) = params.bias_tokens.get(tid) { - logit_override - } else if self.repetition_penalty_tokens().contains(&(i as TokenId)) { - // repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858) - // credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main - - // if score < 0 then repetition penalty has to multiplied to reduce the previous token probability - if logits[i] < 0.0 { - logit * scale * params.repeat_penalty - } else { - logit * scale / params.repeat_penalty - } - } else { - logit * scale - }; - logits_id.push((val, tid)); - } - } - - // find the top K tokens - { - logits_id.partial_sort(params.top_k, |a, b| { - // Sort descending - b.0.total_cmp(&a.0) - }); - logits_id.truncate(params.top_k); - } - - let maxl = logits_id - .iter() - .map(|x| x.0) - .max_by(f32::total_cmp) - .unwrap(); - - // compute probs for the top K tokens - let mut probs: Vec = logits_id - .iter() - .copied() - .map(|(k, _)| (k - maxl).exp()) - .collect(); - let sum: f32 = probs.iter().copied().sum(); - - // Normalize the probs - for p in probs.iter_mut() { - *p /= sum; - } - - // Top p sampling - if params.top_p < 1.0 { - let mut cumsum = 0.0; - for i in 0..probs.len() 
{ - cumsum += probs[i]; - if cumsum >= params.top_p { - probs.truncate(i + 1); - logits_id.truncate(i + 1); - break; - } - } - - cumsum = 1.0 / cumsum; - for p in probs.iter_mut() { - *p *= cumsum; - } - } - - let dist = WeightedIndex::new(&probs).expect("WeightedIndex error"); - let idx = dist.sample(rng); - - logits_id[idx].1 - } - - /// Obtains a serializable snapshot of the current inference status. This - /// can be used to cache the state of the model and store them into a file. - /// - /// # Safety - /// - /// This function provides raw access to the underlying memory owned by the - /// ggml context. While the provided `InferenceSnapshotRef` object is alive, - /// no other methods for this model object should be called. - pub unsafe fn get_snapshot(&mut self) -> InferenceSnapshotRef<'_> { - let memory_k = unsafe { - slice::from_raw_parts(self.memory_k.data() as *mut u8, self.memory_k.nbytes()) - }; - let memory_v = unsafe { - slice::from_raw_parts(self.memory_v.data() as *mut u8, self.memory_v.nbytes()) - }; - - InferenceSnapshotRef { - npast: self.n_past, - session_params: self.params, - tokens: self.tokens.clone(), - logits: self.last_logits.clone(), - memory_k, - memory_v, - } - } -} - -impl Vocabulary { - // SentencePiece implementation after https://guillaume-be.github.io/2020-05-30/sentence_piece - /// Tokenize a `text` with this vocabulary. - /// - /// `bos` controls whether a beginning-of-string token should be inserted. 
- pub fn tokenize<'a>( - &'a self, - text: &str, - bos: bool, - ) -> Result, InferenceError> { - let len = text.len(); - - let mut score = vec![0usize; len + 1]; - let mut prev = vec![TokenId::default(); len + 1]; - - for i in 0..len { - let max_len = (len - i).min(self.max_token_length); - for sub_len in 1..=max_len { - let sub = &text.as_bytes()[i..i + sub_len]; - let token = self.token_to_id.get(sub); - - if let Some(token) = token { - let token_score = sub.len() * sub.len(); - let local_score = score[i] + token_score; - let next = i + sub_len; - - if score[next] < local_score { - score[next] = local_score; - prev[next] = *token; - } - } - } - } - - // Backward pass - let mut res = vec![]; - let mut i = len; - while i > 0 { - let token_id = prev[i]; - if token_id == 0 { - return Err(InferenceError::TokenizationFailed); - } - let token = self.id_to_token[token_id as usize].as_slice(); - res.push((token, token_id)); - i -= token.len(); - } - - if bos { - // TODO: replace with vocab.bos - res.push((&[], 1)); - } - - // Pieces are in reverse order so correct that - res.reverse(); - - Ok(res) - } -} - -/// Used to buffer incoming tokens until they produce a valid string of UTF-8 text. -/// -/// Tokens are *not* valid UTF-8 by themselves. However, the LLM will produce valid UTF-8 -/// from multiple tokens. This helps alleviate that issue. -#[derive(Clone, PartialEq, Default)] -pub struct TokenUtf8Buffer(Vec); -impl TokenUtf8Buffer { - /// Create a new buffer. - pub const fn new() -> Self { - Self(vec![]) - } - - /// Add a token to the buffer. If the buffer contains a valid string of UTF-8 text, - /// it is returned and the buffer is cleared for next use. - pub fn push(&mut self, token: &[u8]) -> Option { - self.0.extend_from_slice(token); - match std::str::from_utf8(&self.0) { - Ok(s) => { - let out = s.to_owned(); - self.0 = vec![]; - Some(out) - } - Err(..) 
=> { - for i in 1..self.0.len() { - let slice = &self.0[i..]; - if slice.is_empty() { - break; - } - - if let Ok(s) = std::str::from_utf8(slice) { - let out = s.to_owned(); - self.0 = vec![]; - return Some(out); - } - } - None - } - } - } - - /// Adapt a `&str` callback so that it can be used in a `&[u8]` context. - fn adapt_callback<'a, E: std::error::Error + 'static>( - mut callback: impl FnMut(&str) -> Result<(), E> + 'a, - ) -> impl FnMut(&[u8]) -> Result<(), E> + 'a { - let mut buffer = Self::new(); - move |token| match buffer.push(token) { - Some(tokens) => callback(&tokens), - None => Ok(()), - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_valid_utf8() { - let mut buffer = TokenUtf8Buffer::new(); - assert_eq!(buffer.push(b"hello").as_deref(), Some("hello")); - assert_eq!(buffer.push(&[0xE2, 0x82, 0xAC]).as_deref(), Some("€")); - } - - #[test] - fn test_partial_utf8() { - let mut buffer = TokenUtf8Buffer::new(); - assert_eq!(buffer.push(&[0xE2, 0x82]).as_deref(), None); - assert_eq!(buffer.push(&[0xAC]).as_deref(), Some("€")); - } - - #[test] - fn test_invalid_prelude_for_valid_utf8() { - let mut buffer = TokenUtf8Buffer::new(); - assert_eq!(buffer.push(&[0xD8]).as_deref(), None); - assert_eq!(buffer.push(&[0xE2, 0x82]).as_deref(), None); - assert_eq!(buffer.push(&[0xAC]).as_deref(), Some("€")); - } -} diff --git a/llama-rs/src/loader.rs b/llama-rs/src/loader.rs new file mode 100644 index 00000000..f99344dd --- /dev/null +++ b/llama-rs/src/loader.rs @@ -0,0 +1,554 @@ +use std::{ + collections::HashMap, + io::{BufRead, Read, Seek, SeekFrom}, + path::{Path, PathBuf}, +}; + +use thiserror::Error; + +use crate::{ + util::{self, mulf}, + vocabulary::TokenId, + Hyperparameters, Model, Vocabulary, +}; + +/// Each variant represents a step within the process of loading the model. +/// These can be used to report progress to the user. 
+#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Debug)] +pub enum LoadProgress<'a> { + /// The hyperparameters have been loaded from the model. + HyperparametersLoaded(&'a Hyperparameters), + /// The context has been created. + ContextSize { + /// The size of the context. + bytes: usize, + }, + /// A part of the model is being loaded. + PartLoading { + /// The path to the model part. + file: &'a Path, + /// The current part (0-indexed). + current_part: usize, + /// The number of total parts. + total_parts: usize, + }, + /// A tensor from the current part has been loaded. + PartTensorLoaded { + /// The path to the model part. + file: &'a Path, + /// The current tensor (0-indexed). + current_tensor: usize, + /// The number of total tensors. + tensor_count: usize, + }, + /// A model part has finished fully loading. + PartLoaded { + /// The path to the model part. + file: &'a Path, + /// The number of bytes in the part. + byte_size: usize, + /// The number of tensors in the part. + tensor_count: usize, + }, +} + +#[derive(Error, Debug)] +/// Errors encountered during the loading process. +pub enum LoadError { + #[error("could not open file {path:?}")] + /// A file failed to open. + OpenFileFailed { + /// The original error. + source: std::io::Error, + /// The path that failed. + path: PathBuf, + }, + #[error("no parent path for {path:?}")] + /// There is no parent path for a given path. + NoParentPath { + /// The path without a parent. + path: PathBuf, + }, + #[error("unable to read exactly {bytes} bytes")] + /// Reading exactly `bytes` from a file failed. + ReadExactFailed { + /// The original error. + source: std::io::Error, + /// The number of bytes that were attempted to be read. + bytes: usize, + }, + #[error("non-specific I/O error")] + /// A non-specific IO error. + IO(#[from] std::io::Error), + #[error("could not convert bytes to a UTF-8 string")] + /// One of the strings encountered was not valid UTF-8. 
+ InvalidUtf8(#[from] std::string::FromUtf8Error), + #[error("invalid integer conversion")] + /// One of the integers encountered could not be converted to a more appropriate type. + InvalidIntegerConversion(#[from] std::num::TryFromIntError), + #[error("invalid magic number for {path:?}")] + /// An invalid magic number was encountered during the loading process. + InvalidMagic { + /// The path that failed. + path: PathBuf, + }, + #[error("invalid file format version {value}")] + /// The version of the format is not supported by this version of `llama-rs`. + InvalidFormatVersion { + /// The version that was encountered. + value: u32, + }, + #[error("invalid value {ftype} for `f16` in hyperparameters")] + /// The `f16` hyperparameter had an invalid value. + HyperparametersF16Invalid { + /// The format type that was encountered. + ftype: u32, + }, + #[error("unknown tensor `{tensor_name}` in {path:?}")] + /// The tensor `tensor_name` was encountered during the loading of `path`, but was not seen during + /// the model prelude. + UnknownTensor { + /// The name of the tensor. + tensor_name: String, + /// The path that failed. + path: PathBuf, + }, + #[error("the tensor `{tensor_name}` has the wrong size in {path:?}")] + /// The tensor `tensor_name` did not match its expected size. + TensorWrongSize { + /// The name of the tensor. + tensor_name: String, + /// The path that failed. + path: PathBuf, + }, + /// The tensor `tensor_name` did not have the expected format type. + #[error("invalid ftype {ftype} for tensor `{tensor_name}` in {path:?}")] + InvalidFtype { + /// The name of the tensor. + tensor_name: String, + /// The format type that was encountered. + ftype: u32, + /// The path that failed. 
+ path: PathBuf, + }, +} + +pub fn load( + path: impl AsRef, + n_context_tokens: usize, + mut load_progress_callback: impl FnMut(LoadProgress), +) -> Result { + use std::fs::File; + use std::io::BufReader; + + let main_path = path.as_ref(); + + let mut reader = + BufReader::new( + File::open(main_path).map_err(|e| LoadError::OpenFileFailed { + source: e, + path: main_path.to_owned(), + })?, + ); + + // Verify magic + let is_legacy_model: bool = match read_u32(&mut reader)? { + ggml::FILE_MAGIC => false, + ggml::FILE_MAGIC_UNVERSIONED => true, + _ => { + return Err(LoadError::InvalidMagic { + path: main_path.to_owned(), + }) + } + }; + + // Load format version + if !is_legacy_model { + #[allow(unused_variables)] + let version: u32 = match read_u32(&mut reader)? { + ggml::FORMAT_VERSION => ggml::FORMAT_VERSION, + version => return Err(LoadError::InvalidFormatVersion { value: version }), + }; + } + + // ================= + // Load hyper params + // ================= + + // NOTE: Field order matters! Data is laid out in the file exactly + // in this order. 
+ let hparams = Hyperparameters { + n_vocab: read_i32(&mut reader)?.try_into()?, + n_ctx: n_context_tokens, + n_embd: read_i32(&mut reader)?.try_into()?, + n_mult: read_i32(&mut reader)?.try_into()?, + n_head: read_i32(&mut reader)?.try_into()?, + n_layer: read_i32(&mut reader)?.try_into()?, + n_rot: read_i32(&mut reader)?.try_into()?, + f16_: read_i32(&mut reader)?.try_into()?, + }; + + let n_ff = + ((2 * (4 * hparams.n_embd) / 3 + hparams.n_mult - 1) / hparams.n_mult) * hparams.n_mult; + + load_progress_callback(LoadProgress::HyperparametersLoaded(&hparams)); + + // =============== + // Load vocabulary + // =============== + let vocabulary = { + let mut id_to_token = vec![]; + let mut id_to_token_score = vec![]; + let mut token_to_id = HashMap::new(); + let mut max_token_length = 0; + + for i in 0..hparams.n_vocab { + let len = read_i32(&mut reader)?; + let token = read_bytes_with_len(&mut reader, len as usize)?; + max_token_length = max_token_length.max(token.len()); + id_to_token.push(token.clone()); + token_to_id.insert(token, TokenId::try_from(i)?); + + // Token score, currently unused + if !is_legacy_model { + if let Ok(score) = read_f32(&mut reader) { + id_to_token_score.push(score); + } + } else { + // Legacy model, set empty score + id_to_token_score.push(0.); + } + } + + Vocabulary { + id_to_token, + id_to_token_score, + token_to_id, + max_token_length, + } + }; + + // for the big tensors, we have the option to store the data in 16-bit + // floats or quantized in order to save memory and also to speed up the + // computation + let wtype = match hparams.f16_ { + 0 => ggml::Type::F32, + 1 => ggml::Type::F16, + 2 => ggml::Type::Q4_0, + 3 => ggml::Type::Q4_1, + invalid => return Err(LoadError::HyperparametersF16Invalid { ftype: invalid }), + }; + + let n_embd = hparams.n_embd; + let n_layer = hparams.n_layer; + let n_vocab = hparams.n_vocab; + + let ctx_size = { + // Use 64-bit math to prevent overflow. 
+ let mut ctx_size: usize = 0; + + ctx_size += mulf!(n_embd, n_vocab, ggml::type_sizef(wtype)); // tok_embeddings + + ctx_size += mulf!(n_embd, ggml::type_sizef(ggml::Type::F32)); // norm + + ctx_size += mulf!(n_embd, n_vocab, ggml::type_sizef(wtype)); // output + + ctx_size += mulf!(n_layer, n_embd, ggml::type_sizef(ggml::Type::F32)); // attention_norm + + ctx_size += mulf!(n_layer, n_embd, n_embd, ggml::type_sizef(wtype)); // wq + ctx_size += mulf!(n_layer, n_embd, n_embd, ggml::type_sizef(wtype)); // wk + ctx_size += mulf!(n_layer, n_embd, n_embd, ggml::type_sizef(wtype)); // wv + ctx_size += mulf!(n_layer, n_embd, n_embd, ggml::type_sizef(wtype)); // wo + + ctx_size += mulf!(n_layer, n_embd, ggml::type_sizef(ggml::Type::F32)); // ffn_norm + + ctx_size += mulf!(n_layer, n_ff, n_embd, ggml::type_sizef(wtype)); // w1 + ctx_size += mulf!(n_layer, n_ff, n_embd, ggml::type_sizef(wtype)); // w2 + ctx_size += mulf!(n_layer, n_ff, n_embd, ggml::type_sizef(wtype)); // w3 + + ctx_size += (5 + 10 * n_layer) * 256; // object overhead + + load_progress_callback(LoadProgress::ContextSize { bytes: ctx_size }); + + ctx_size + }; + + // Initialize the context + let context = ggml::Context::init(ctx_size); + + let model = Model::new(context, hparams, vocabulary, n_ff, wtype); + + // Close the file, but keep its offset. That way we know how to skip the + // metadata when loading the parts. 
+ let file_offset = reader.stream_position()?; + drop(reader); + + let paths = util::find_all_model_files(main_path)?; + let n_parts = paths.len(); + + for (i, part_path) in paths.into_iter().enumerate() { + let part_id = i; + + load_progress_callback(LoadProgress::PartLoading { + file: &part_path, + current_part: i, + total_parts: n_parts, + }); + + let mut part_reader = BufReader::new(File::open(&part_path)?); + + // Skip metadata + part_reader.seek(SeekFrom::Start(file_offset))?; + + let mut total_size = 0; + let mut n_tensors = 0; + + // Load weights + loop { + // NOTE: Implementation from #![feature(buf_read_has_data_left)] + let is_eof = part_reader.fill_buf().map(|b| b.is_empty())?; + + if is_eof { + break; + } + + let n_dims = usize::try_from(read_i32(&mut part_reader)?)?; + let length = read_i32(&mut part_reader)?; + let ftype = read_u32(&mut part_reader)?; + + let mut nelements = 1; + let mut ne = [1i64, 1i64]; + + #[allow(clippy::needless_range_loop)] + for i in 0..n_dims { + ne[i] = read_i32(&mut part_reader)? 
as i64; + nelements *= usize::try_from(ne[i])?; + } + + let tensor_name = read_string(&mut part_reader, length as usize)?; + + let Some(tensor) = model.tensors().get(&tensor_name) + else { + return Err(LoadError::UnknownTensor { tensor_name, path: part_path }); + }; + + // split_type = 0: split by columns + // split_type = 1: split by rows + // + // split_type = 0: + // regex: + // - tok_embeddings.* + // - layers.*.attention.wo.weight + // - layers.*.feed_forward.w2.weight + + // split_type = 1: + // regex: + // - output.* + // - layers.*.attention.wq.weight + // - layers.*.attention.wk.weight + // - layers.*.attention.wv.weight + // - layers.*.feed_forward.w1.weight + // - layers.*.feed_forward.w3.weight + #[allow(clippy::if_same_then_else)] + let split_type = if tensor_name.contains("tok_embeddings") { + 0 + } else if tensor_name.contains("layers") { + if tensor_name.contains("attention.wo.weight") { + 0 + } else if tensor_name.contains("feed_forward.w2.weight") { + 0 + } else { + 1 + } + } else if tensor_name.contains("output") { + 1 + } else { + 0 + }; + + if n_dims == 1 { + if tensor.nelements() != nelements { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: part_path, + }); + } + } else if tensor.nelements() / n_parts != nelements { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: part_path, + }); + } + + if n_dims == 1 { + if tensor.get_ne()[0] != ne[0] || tensor.get_ne()[1] != ne[1] { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: part_path, + }); + } + } else if split_type == 0 { + if tensor.get_ne()[0] / i64::try_from(n_parts)? != ne[0] + || tensor.get_ne()[1] != ne[1] + { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: part_path, + }); + } + } else if tensor.get_ne()[0] != ne[0] + || tensor.get_ne()[1] / i64::try_from(n_parts)? 
!= ne[1] + { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: part_path, + }); + } + + let bpe = match ftype { + 0 => ggml::type_size(ggml::Type::F32), + 1 => ggml::type_size(ggml::Type::F16), + 2 => { + assert_eq!(ne[0] % 64, 0); + ggml::type_size(ggml::Type::Q4_0) + } + 3 => { + assert_eq!(ne[0] % 64, 0); + ggml::type_size(ggml::Type::Q4_1) + } + _ => { + return Err(LoadError::InvalidFtype { + tensor_name, + ftype, + path: part_path, + }) + } + }; + + if n_dims == 1 || n_parts == 1 { + if (nelements * bpe) / ggml::blck_size(tensor.get_type()) != tensor.nbytes() { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: part_path, + }); + } + + if part_id == 0 { + // SAFETY: yolo, same as original code + let slice = unsafe { + let data = tensor.data(); + std::slice::from_raw_parts_mut(data as *mut u8, tensor.nbytes()) + }; + part_reader.read_exact(slice)?; + } else { + part_reader.seek(SeekFrom::Current(tensor.nbytes() as i64))?; + } + + total_size += tensor.nbytes(); + } else { + if (nelements * bpe) / ggml::blck_size(tensor.get_type()) + != tensor.nbytes() / n_parts + { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: part_path, + }); + } + + if split_type == 0 { + let np0 = ne[0]; + let row_size = (usize::try_from(tensor.get_ne()[0])? + / ggml::blck_size(tensor.get_type())) + * ggml::type_size(tensor.get_type()); + + assert_eq!(row_size, tensor.get_nb()[1]); + + for i1 in 0..ne[1] { + let offset_row = i1 as usize * row_size; + let offset = offset_row + + ((part_id * np0 as usize) / ggml::blck_size(tensor.get_type())) + * ggml::type_size(tensor.get_type()); + // SAFETY: yolo, same as original code + unsafe { + let ptr = tensor.data().add(offset); + let slice = + std::slice::from_raw_parts_mut(ptr as *mut u8, row_size / n_parts); + part_reader.read_exact(slice)?; + } + } + } else { + let np1 = ne[1]; + let row_size = (usize::try_from(tensor.get_ne()[0])? 
+ / ggml::blck_size(tensor.get_type())) + * ggml::type_size(tensor.get_type()); + + for i1 in 0..ne[1] { + let offset_row = (i1 as usize + part_id * np1 as usize) * row_size; + // SAFETY: yolo, same as original code + unsafe { + let ptr = tensor.data().add(offset_row); + let slice = std::slice::from_raw_parts_mut(ptr as *mut u8, row_size); + part_reader.read_exact(slice)?; + } + } + } + + total_size += tensor.nbytes() / n_parts; + } + + n_tensors += 1; + load_progress_callback(LoadProgress::PartTensorLoaded { + file: &part_path, + current_tensor: n_tensors.try_into()?, + tensor_count: model.tensors().len(), + }); + } + + load_progress_callback(LoadProgress::PartLoaded { + file: &part_path, + byte_size: total_size, + tensor_count: n_tensors.try_into()?, + }); + } + + Ok(model) +} + +pub fn read_bytes(reader: &mut impl BufRead) -> Result<[u8; N], LoadError> { + let mut bytes = [0u8; N]; + reader + .read_exact(&mut bytes) + .map_err(|e| LoadError::ReadExactFailed { + source: e, + bytes: N, + })?; + Ok(bytes) +} + +pub fn read_bytes_with_len(reader: &mut impl BufRead, len: usize) -> Result, LoadError> { + let mut bytes = vec![0u8; len]; + reader + .read_exact(&mut bytes) + .map_err(|e| LoadError::ReadExactFailed { + source: e, + bytes: len, + })?; + Ok(bytes) +} + +pub fn read_i32(reader: &mut impl BufRead) -> Result { + Ok(i32::from_le_bytes(read_bytes::<4>(reader)?)) +} + +pub fn read_u32(reader: &mut impl BufRead) -> Result { + Ok(u32::from_le_bytes(read_bytes::<4>(reader)?)) +} + +pub fn read_f32(reader: &mut impl BufRead) -> Result { + Ok(f32::from_le_bytes(read_bytes::<4>(reader)?)) +} + +/// Helper function. Reads a string from the buffer and returns it. +pub fn read_string(reader: &mut impl BufRead, len: usize) -> Result { + Ok(String::from_utf8(read_bytes_with_len(reader, len)?)?) 
+} diff --git a/llama-rs/src/model.rs b/llama-rs/src/model.rs new file mode 100644 index 00000000..370e62df --- /dev/null +++ b/llama-rs/src/model.rs @@ -0,0 +1,469 @@ +use std::{collections::HashMap, path::Path}; + +use serde::Deserialize; + +use crate::{ + loader, vocabulary::TokenId, EvaluateOutputRequest, InferenceParameters, InferenceSession, + InferenceSessionParameters, LoadError, LoadProgress, Vocabulary, +}; + +/// The weights for the LLaMA model. All the mutable state is split into a +/// separate struct `InferenceSession`. +pub struct Model { + pub(crate) hparams: Hyperparameters, + + vocabulary: Vocabulary, + + tok_embeddings: ggml::Tensor, + + norm: ggml::Tensor, + output: ggml::Tensor, + + layers: Vec, + + tensors: HashMap, + + // Must be kept alive for the model + _context: ggml::Context, +} +impl Model { + pub(crate) fn new( + context: ggml::Context, + hparams: Hyperparameters, + vocabulary: Vocabulary, + n_ff: usize, + wtype: ggml::Type, + ) -> Model { + let n_embd = hparams.n_embd; + let n_layer = hparams.n_layer; + let n_vocab = hparams.n_vocab; + + let mut tensors = HashMap::new(); + + let tok_embeddings = context.new_tensor_2d(wtype, n_embd, n_vocab); + let norm = context.new_tensor_1d(ggml::Type::F32, n_embd); + let output = context.new_tensor_2d(wtype, n_embd, n_vocab); + + tensors.insert("tok_embeddings.weight".to_owned(), tok_embeddings.share()); + tensors.insert("norm.weight".to_owned(), norm.share()); + tensors.insert("output.weight".to_owned(), output.share()); + + let mut layers = Vec::new(); + for i in 0..n_layer { + let layer = Layer { + attention_norm: context.new_tensor_1d(ggml::Type::F32, n_embd), + wq: context.new_tensor_2d(wtype, n_embd, n_embd), + wk: context.new_tensor_2d(wtype, n_embd, n_embd), + wv: context.new_tensor_2d(wtype, n_embd, n_embd), + wo: context.new_tensor_2d(wtype, n_embd, n_embd), + ffn_norm: context.new_tensor_1d(ggml::Type::F32, n_embd), + w1: context.new_tensor_2d(wtype, n_embd, n_ff), + w2: 
context.new_tensor_2d(wtype, n_ff, n_embd), + w3: context.new_tensor_2d(wtype, n_embd, n_ff), + }; + + tensors.insert( + format!("layers.{i}.attention_norm.weight"), + layer.attention_norm.share(), + ); + + tensors.insert(format!("layers.{i}.attention.wq.weight"), layer.wq.share()); + tensors.insert(format!("layers.{i}.attention.wk.weight"), layer.wk.share()); + tensors.insert(format!("layers.{i}.attention.wv.weight"), layer.wv.share()); + tensors.insert(format!("layers.{i}.attention.wo.weight"), layer.wo.share()); + + tensors.insert( + format!("layers.{i}.ffn_norm.weight"), + layer.ffn_norm.share(), + ); + + tensors.insert( + format!("layers.{i}.feed_forward.w1.weight"), + layer.w1.share(), + ); + tensors.insert( + format!("layers.{i}.feed_forward.w2.weight"), + layer.w2.share(), + ); + tensors.insert( + format!("layers.{i}.feed_forward.w3.weight"), + layer.w3.share(), + ); + + layers.push(layer); + } + + Model { + hparams, + vocabulary, + tok_embeddings, + norm, + output, + layers, + tensors, + _context: context, + } + } + + /// Load the model from `path` with `n_context_tokens` context tokens. + /// + /// The status of the loading process will be reported through `load_progress_callback`. + pub fn load( + path: impl AsRef, + n_context_tokens: usize, + load_progress_callback: impl FnMut(LoadProgress), + ) -> Result { + loader::load(path, n_context_tokens, load_progress_callback) + } + + /// Starts a new `InferenceSession` for this model. + pub fn start_session(&self, params: InferenceSessionParameters) -> InferenceSession { + InferenceSession::new( + params, + self.hparams.n_ctx, + self.hparams.n_layer, + self.hparams.n_embd, + self.hparams.n_vocab, + ) + } + + /// Evaluates the transformer. + /// + /// The provided `output_request` struct lets you specify which additional + /// data you are interested in fetching from the transformer. Setting a + /// field to a `Some` value will clear and fill the provided vector with + /// data. 
The provided vector will be resized to the exact output size. + pub fn evaluate( + &self, + session: &mut InferenceSession, + params: &InferenceParameters, + input_tokens: &[TokenId], + output_request: &mut EvaluateOutputRequest, + ) { + let n = input_tokens.len(); + let n_past = session.n_past; + let n_threads = params.n_threads; + + let memk_elsize = session.memory_k.element_size(); + let memv_elsize = session.memory_v.element_size(); + + let Hyperparameters { + n_vocab, + n_ctx, + n_embd, + n_mult: _, + n_head, + n_layer, + n_rot, + f16_: _, + } = self.hparams; + + // For the first run, we need to guess a maximum buffer size so we can measure + // the actual memory consumption of the temporary ggml context. + // + // These numbers are from `llama.cpp`, and could potentially be more efficient. + let mut buf_size = { + let buf_size_mb = if n_layer >= 80 { + 1536 + } else if n_layer >= 60 { + 1280 + } else { + 1024 + }; + buf_size_mb * 1024 * 1024 + }; + if session.mem_per_token > 0 && session.mem_per_token * n > buf_size { + // add 10% to account for ggml object overhead + buf_size = (1.1f64 * session.mem_per_token as f64 * n as f64) as usize; + }; + let ctx0 = ggml::Context::init(buf_size); + + let mut gf = ggml::ComputationGraph::new(n_threads); + + let embd = ctx0.new_tensor_1d(ggml::Type::I32, n); + unsafe { embd.write_data(bytemuck::cast_slice(input_tokens)) }; + + let mut input_layer = ctx0.op_get_rows(&self.tok_embeddings, &embd); + + for il in 0..n_layer { + let input_self_attention = input_layer.share(); + let mut current: ggml::Tensor; + + ctx0.use_scratch(Some(&mut session.scratch[0])); + + // norm + { + current = ctx0.op_rms_norm(&input_layer); + + // cur = attention_norm * cur + current = ctx0.op_mul( + &ctx0.op_repeat(&self.layers[il].attention_norm, ¤t), + ¤t, + ); + } + + // self-attention + { + // compute Q and K and RoPE them + let q_current = ctx0.op_rope( + &ctx0.op_reshape_3d( + &ctx0.op_mul_mat(&self.layers[il].wq, ¤t), + n_embd / n_head, + 
n_head, + n, + ), + n_past, + n_rot, + 0, + ); + let k_current = ctx0.op_rope( + &ctx0.op_reshape_3d( + &ctx0.op_mul_mat(&self.layers[il].wk, ¤t), + n_embd / n_head, + n_head, + n, + ), + n_past, + n_rot, + 0, + ); + + // store key and value to memory + { + // compute the transposed [N, n_embd] V matrix + let v_current = ctx0.op_transpose(&ctx0.op_reshape_2d( + &ctx0.op_mul_mat(&self.layers[il].wv, ¤t), + n_embd, + n, + )); + + let k = ctx0.op_view_1d( + &session.memory_k, + n * n_embd, + (memk_elsize * n_embd) * (il * n_ctx + n_past), + ); + + let v = ctx0.op_view_2d( + &session.memory_v, + n, + n_embd, + n_ctx * memv_elsize, + (il * n_ctx) * memv_elsize * n_embd + n_past * memv_elsize, + ); + + // important: storing RoPE-ed version of K in the KV cache! + gf.build_forward_expand(&ctx0.op_cpy(&k_current, &k)); + gf.build_forward_expand(&ctx0.op_cpy(&v_current, &v)); + } + + let q = ctx0.op_permute(&q_current, 0, 2, 1, 3); + + let k = ctx0.op_permute( + &ctx0.op_reshape_3d( + &ctx0.op_view_1d( + &session.memory_k, + (n_past + n) * n_embd, + il * n_ctx * memk_elsize * n_embd, + ), + n_embd / n_head, + n_head, + n_past + n, + ), + 0, + 2, + 1, + 3, + ); + + // K * Q + let k_q = ctx0.op_mul_mat(&k, &q); + + // KQ_scaled = KQ / sqrt(n_embd/n_head) + let k_q_scaled = ctx0.op_scale( + &k_q, + &ctx0.new_f32(1.0 / f32::sqrt(n_embd as f32 / n_head as f32)), + ); + + // KQ_masked = mask_past(KQ_scaled) + let k_q_masked = ctx0.op_diag_mask_inf(&k_q_scaled, n_past); + + // KQ = soft_max(KQ_masked) + let k_q_soft_max = ctx0.op_soft_max(&k_q_masked); + + // split cached V into n_head heads + let v = ctx0.op_view_3d( + &session.memory_v, + n_past + n, + n_embd / n_head, + n_head, + n_ctx * memv_elsize, + n_ctx * memv_elsize * n_embd / n_head, + il * n_ctx * memv_elsize * n_embd, + ); + + let k_q_v = ctx0.op_mul_mat(&v, &k_q_soft_max); + + // KQV_merged = KQV.permute(0, 2, 1, 3) + let k_q_v_merged = ctx0.op_permute(&k_q_v, 0, 2, 1, 3); + + // cur = 
KQV_merged.contiguous().view(n_embd, N) + current = ctx0.op_cpy( + &k_q_v_merged, + &ctx0.new_tensor_2d(ggml::Type::F32, n_embd, n), + ); + + // projection (no bias) + current = ctx0.op_mul_mat(&self.layers[il].wo, ¤t); + } + + ctx0.use_scratch(Some(&mut session.scratch[1])); + + let input_feed_forward = ctx0.op_add(¤t, &input_self_attention); + + // feed-forward network + { + // norm + { + current = ctx0.op_rms_norm(&input_feed_forward); + + // cur = ffn_norm*cur + current = ctx0.op_mul( + &ctx0.op_repeat(&self.layers[il].ffn_norm, ¤t), + ¤t, + ); + } + + let tmp = ctx0.op_mul_mat(&self.layers[il].w3, ¤t); + + current = ctx0.op_mul_mat(&self.layers[il].w1, ¤t); + + // SILU activation + current = ctx0.op_silu(¤t); + + current = ctx0.op_mul(¤t, &tmp); + + current = ctx0.op_mul_mat(&self.layers[il].w2, ¤t); + } + + current = ctx0.op_add(¤t, &input_feed_forward); + + // input for next layer + input_layer = current; + } + + ctx0.use_scratch(Some(&mut session.scratch[0])); + + // Used at the end to optionally extract the embeddings. 
+ let embeddings_tensor; + + // norm + { + input_layer = ctx0.op_rms_norm(&input_layer); + + // inpL = norm*inpL + input_layer = ctx0.op_mul(&ctx0.op_repeat(&self.norm, &input_layer), &input_layer); + embeddings_tensor = input_layer.share(); + } + + // lm_head + { + input_layer = ctx0.op_mul_mat(&self.output, &input_layer); + } + + ctx0.use_scratch(None); + + // logits -> probs + // inpL = ctx0.op_soft_max(&inpL); + + // run the computation + gf.build_forward_expand(&input_layer); + ctx0.graph_compute(&mut gf); + + // return result for just the last token + // SAFETY: yolo + assert_eq!(session.last_logits.len(), n_vocab); + unsafe { + input_layer.read_data( + n_vocab * (n - 1) * std::mem::size_of::(), + bytemuck::cast_slice_mut(&mut session.last_logits), + ) + }; + + // Extract logits + if let Some(all_logits) = &mut output_request.all_logits { + all_logits.resize(n_vocab * n, 0.0); + // SAFETY: Tensor data can be read (properly aligned, initialized, + // data will not be mutated or otherwise aliased during the copy), + // and we're not reading past the end of the tensor data. + assert_eq!(input_layer.nelements(), n_vocab * n); + unsafe { + input_layer.read_data(0, bytemuck::cast_slice_mut(all_logits)); + } + } + + // Extract embeddings + if let Some(embeddings) = &mut output_request.embeddings { + embeddings.resize(n_embd * n, 0.0); + // SAFETY: Same rationale as for the "Extract logits" section applies. + assert_eq!(embeddings_tensor.nelements(), n_embd * n); + unsafe { + embeddings_tensor.read_data(0, bytemuck::cast_slice_mut(embeddings)); + } + } + + // Adjust the required memory per token if we didn't know that already + if session.mem_per_token == 0 { + session.mem_per_token = ctx0.used_mem() / n; + } + + // Adjust n_past to new length. + session.n_past += input_tokens.len(); + } + + /// Returns the vocabulary used by this model. 
+ pub fn vocabulary(&self) -> &Vocabulary { + &self.vocabulary + } + + pub(crate) fn tensors(&self) -> &HashMap { + &self.tensors + } +} + +/// The hyperparameters of the model. +#[derive(Debug, Default, PartialEq, Eq, PartialOrd, Ord, Deserialize)] +pub struct Hyperparameters { + /// n_vocab + pub n_vocab: usize, + /// n_ctx + pub n_ctx: usize, + /// n_embd + pub n_embd: usize, + /// n_mult + pub n_mult: usize, + /// n_head + pub n_head: usize, + /// n_layer + pub n_layer: usize, + /// n_rot + pub n_rot: usize, + /// f16_ + pub f16_: u32, +} + +struct Layer { + attention_norm: ggml::Tensor, + + wq: ggml::Tensor, + wk: ggml::Tensor, + wv: ggml::Tensor, + wo: ggml::Tensor, + + // normalization + ffn_norm: ggml::Tensor, + + // ff + w1: ggml::Tensor, + w2: ggml::Tensor, + w3: ggml::Tensor, +} diff --git a/llama-rs/src/util.rs b/llama-rs/src/util.rs index 541bb529..3eb8f06d 100644 --- a/llama-rs/src/util.rs +++ b/llama-rs/src/util.rs @@ -2,6 +2,73 @@ use std::path::{Path, PathBuf}; use crate::LoadError; +/// NOTE: The original code relies in promotion rules and automatic cast between +/// int to float. What we do instead is use this macro to convert every term of +/// the multiplication to f64, which should have enough precision bits to hold +/// the final value, then cast to usize. I have observed a discrepancy between +/// the ctx_size found using this code, and the one in llama.cpp. The number for +/// rust ends up being slightly lower, but no "out of memory" errors are +/// reported by ggml. +macro_rules! mulf { + ($term:expr, $($terms:expr),*) => { + usize::try_from((($term as f64) $(* ($terms as f64))*) as u64).unwrap() + }; +} + +pub(crate) use mulf; + +/// Used to buffer incoming tokens until they produce a valid string of UTF-8 text. +/// +/// Tokens are *not* valid UTF-8 by themselves. However, the LLM will produce valid UTF-8 +/// from multiple tokens. This helps alleviate that issue. 
+#[derive(Clone, PartialEq, Default)] +pub struct TokenUtf8Buffer(Vec); +impl TokenUtf8Buffer { + /// Create a new buffer. + pub const fn new() -> Self { + Self(vec![]) + } + + /// Add a token to the buffer. If the buffer contains a valid string of UTF-8 text, + /// it is returned and the buffer is cleared for next use. + pub fn push(&mut self, token: &[u8]) -> Option { + self.0.extend_from_slice(token); + match std::str::from_utf8(&self.0) { + Ok(s) => { + let out = s.to_owned(); + self.0 = vec![]; + Some(out) + } + Err(..) => { + for i in 1..self.0.len() { + let slice = &self.0[i..]; + if slice.is_empty() { + break; + } + + if let Ok(s) = std::str::from_utf8(slice) { + let out = s.to_owned(); + self.0 = vec![]; + return Some(out); + } + } + None + } + } + } + + /// Adapt a `&str` callback so that it can be used in a `&[u8]` context. + pub fn adapt_callback<'a, E: std::error::Error + 'static>( + mut callback: impl FnMut(&str) -> Result<(), E> + 'a, + ) -> impl FnMut(&[u8]) -> Result<(), E> + 'a { + let mut buffer = Self::new(); + move |token| match buffer.push(token) { + Some(tokens) => callback(&tokens), + None => Ok(()), + } + } +} + pub(crate) fn find_all_model_files(main_path: &Path) -> Result, LoadError> { Ok(collect_related_paths( main_path, @@ -67,4 +134,26 @@ mod tests { let output_paths = collect_related_paths(&main_path, directory_paths.into_iter()); assert_eq!(expected_paths.as_slice(), output_paths); } + + #[test] + fn test_valid_utf8() { + let mut buffer = TokenUtf8Buffer::new(); + assert_eq!(buffer.push(b"hello").as_deref(), Some("hello")); + assert_eq!(buffer.push(&[0xE2, 0x82, 0xAC]).as_deref(), Some("€")); + } + + #[test] + fn test_partial_utf8() { + let mut buffer = TokenUtf8Buffer::new(); + assert_eq!(buffer.push(&[0xE2, 0x82]).as_deref(), None); + assert_eq!(buffer.push(&[0xAC]).as_deref(), Some("€")); + } + + #[test] + fn test_invalid_prelude_for_valid_utf8() { + let mut buffer = TokenUtf8Buffer::new(); + 
assert_eq!(buffer.push(&[0xD8]).as_deref(), None); + assert_eq!(buffer.push(&[0xE2, 0x82]).as_deref(), None); + assert_eq!(buffer.push(&[0xAC]).as_deref(), Some("€")); + } } diff --git a/llama-rs/src/vocabulary.rs b/llama-rs/src/vocabulary.rs new file mode 100644 index 00000000..80e619c7 --- /dev/null +++ b/llama-rs/src/vocabulary.rs @@ -0,0 +1,152 @@ +use std::{collections::HashMap, str::FromStr}; + +use crate::InferenceError; + +/// The identifier of a token in a vocabulary. +pub type TokenId = i32; +pub(crate) type Token = Vec; +pub(crate) type TokenScore = f32; + +/// The vocabulary used by a model. +#[derive(Debug, Clone)] +pub struct Vocabulary { + /// Maps every integer (index) token id to its corresponding token + pub(crate) id_to_token: Vec, + + /// Maps every integer (index) token id to corresponding score + pub(crate) id_to_token_score: Vec, + + /// Maps a token to a token id + pub(crate) token_to_id: HashMap, + + /// The longest token in this vocabulary + pub(crate) max_token_length: usize, +} +impl Vocabulary { + pub(crate) fn token(&self, idx: usize) -> &[u8] { + &self.id_to_token[idx] + } + + // SentencePiece implementation after https://guillaume-be.github.io/2020-05-30/sentence_piece + /// Tokenize a `text` with this vocabulary. + /// + /// `bos` controls whether a beginning-of-string token should be inserted. 
+ pub fn tokenize<'a>( + &'a self, + text: &str, + bos: bool, + ) -> Result, InferenceError> { + let len = text.len(); + + let mut score = vec![0usize; len + 1]; + let mut prev = vec![TokenId::default(); len + 1]; + + for i in 0..len { + let max_len = (len - i).min(self.max_token_length); + for sub_len in 1..=max_len { + let sub = &text.as_bytes()[i..i + sub_len]; + let token = self.token_to_id.get(sub); + + if let Some(token) = token { + let token_score = sub.len() * sub.len(); + let local_score = score[i] + token_score; + let next = i + sub_len; + + if score[next] < local_score { + score[next] = local_score; + prev[next] = *token; + } + } + } + } + + // Backward pass + let mut res = vec![]; + let mut i = len; + while i > 0 { + let token_id = prev[i]; + if token_id == 0 { + return Err(InferenceError::TokenizationFailed); + } + let token = self.id_to_token[token_id as usize].as_slice(); + res.push((token, token_id)); + i -= token.len(); + } + + if bos { + // TODO: replace with vocab.bos + res.push((&[], 1)); + } + + // Pieces are in reverse order so correct that + res.reverse(); + + Ok(res) + } +} + +#[derive(Default, Clone, Debug, PartialEq)] +/// A list of tokens to bias during the process of inferencing. +/// +/// When a biased token is encountered, the bias will be used +/// instead of the inferred logit during the sampling process. +/// +/// This can be used to disable the generation of responses +/// with specific tokens by setting their corresponding bias +/// to -1.0. +pub struct TokenBias(Vec<(TokenId, f32)>); + +impl TokenBias { + /// Create a [TokenBias] from an existing `Vec`. + pub fn new(mut v: Vec<(TokenId, f32)>) -> Self { + v.sort_by_cached_key(|(tid, _)| *tid); + v.dedup_by_key(|(tid, _)| *tid); + Self(v) + } + + /// Retrieves the bias for a given token, if available. 
+ pub fn get(&self, tid: TokenId) -> Option { + self.0 + .binary_search_by_key(&tid, |(tid, _)| *tid) + .map(|idx| self.0[idx].1) + .ok() + } +} + +impl FromStr for TokenBias { + type Err = String; + + /// A comma separated list of token biases. The list should be in the format + /// "TID=BIAS,TID=BIAS" where TID is an integer token ID and BIAS is a + /// floating point number. + /// For example, "1=-1.0,2=-1.0" sets the bias for token IDs 1 + /// (start of document) and 2 (end of document) to -1.0 which effectively + /// disables the model from generating responses containing those token IDs. + fn from_str(s: &str) -> Result { + let x = s + .split(',') + .map(|kv| { + let (k, v) = kv + .trim() + .split_once('=') + .ok_or_else(|| "Missing '=' in bias item".to_owned())?; + let tid: TokenId = k + .trim() + .parse() + .map_err(|e: std::num::ParseIntError| e.to_string())?; + let bias: f32 = v + .trim() + .parse() + .map_err(|e: std::num::ParseFloatError| e.to_string())?; + Result::<_, String>::Ok((tid, bias)) + }) + .collect::>()?; + Ok(TokenBias::new(x)) + } +} + +impl std::fmt::Display for TokenBias { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self.0) + } +}