diff --git a/Cargo.lock b/Cargo.lock
index 38bd0498..6a31264a 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -533,6 +533,14 @@ dependencies = [
  "zstd",
 ]
 
+[[package]]
+name = "llama-loader"
+version = "0.1.0"
+dependencies = [
+ "ggml",
+ "thiserror",
+]
+
 [[package]]
 name = "llama-rs"
 version = "0.1.0"
@@ -540,6 +548,7 @@ dependencies = [
  "bincode",
  "bytemuck",
  "ggml",
+ "llama-loader",
  "memmap2",
  "partial_sort",
  "protobuf",
diff --git a/Cargo.toml b/Cargo.toml
index 8ea220d8..4c383de8 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -2,6 +2,7 @@
 members = [
     "ggml-sys",
     "ggml",
+    "llama-loader",
     "llama-rs",
     "llama-cli",
     "generate-ggml-bindings"
diff --git a/ggml/src/lib.rs b/ggml/src/lib.rs
index 915d5692..5aae51fa 100644
--- a/ggml/src/lib.rs
+++ b/ggml/src/lib.rs
@@ -440,7 +440,7 @@ impl Tensor {
     // /// Set the tensor's data pointer (useful for mmap-ed data)
     // ///
     // /// # Safety
-    // /// 
+    // ///
     // /// The memory region from `data_ptr` to `data_ptr.offset(tensor.nbytes())` will be read from.
     // pub unsafe fn set_data(&self, data_ptr: *mut c_void) {
     //     self.with_alive_ctx(|| {
diff --git a/llama-loader/Cargo.toml b/llama-loader/Cargo.toml
new file mode 100644
index 00000000..cfc8d48b
--- /dev/null
+++ b/llama-loader/Cargo.toml
@@ -0,0 +1,10 @@
+[package]
+name = "llama-loader"
+version = "0.1.0"
+edition = "2021"
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
+ggml = { path = "../ggml" }
+thiserror = "*"
diff --git a/llama-rs/src/loader2.rs b/llama-loader/src/lib.rs
similarity index 73%
rename from llama-rs/src/loader2.rs
rename to llama-loader/src/lib.rs
index b7804216..35c1e902 100644
--- a/llama-rs/src/loader2.rs
+++ b/llama-loader/src/lib.rs
@@ -1,15 +1,31 @@
-#![allow(missing_docs)]
+//! standalone model loader
+//!
+//! Only the hyperparameters are llama-specific. Everything else can be reused for other LLMs.
+#![allow(clippy::nonminimal_bool)]
 
-//! standalone model loader
+pub mod util;
 
 use std::{
     io::{BufRead, Seek, SeekFrom},
     ops::ControlFlow,
 };
 
+use util::*;
+
+pub type ElementType = ggml::Type;
+
+/// file type containing the model
+#[derive(Debug, PartialEq, Clone, Copy)]
+#[allow(clippy::upper_case_acronyms)]
+pub enum ContainerType {
+    /// legacy format, oldest ggml tensor file format
+    GGML,
+    /// also legacy format, newer than GGML, older than GGJT
+    GGMF,
+    /// mmap-able format
+    GGJT,
+}
-use crate::{loader::has_data_left, ElementType, ModelContainerType};
-
-pub(crate) fn decode_element_type(ftype: i32) -> Option<ElementType> {
+pub fn decode_element_type(ftype: i32) -> Option<ElementType> {
     match ftype {
         0 => Some(ggml::Type::F32),
         1 => Some(ggml::Type::F16),
@@ -19,7 +35,7 @@ pub(crate) fn decode_element_type(ftype: i32) -> Option<ElementType> {
     }
 }
 
-pub(crate) fn encode_element_type(element_type: ElementType) -> Option<i32> {
+pub fn encode_element_type(element_type: ElementType) -> Option<i32> {
     match element_type {
         ggml::Type::F32 => Some(0),
         ggml::Type::F16 => Some(1),
@@ -29,38 +45,9 @@ pub(crate) fn encode_element_type(element_type: ElementType) -> Option<i32> {
     }
 }
 
-pub(crate) fn read_bytes<const N: usize>(
-    reader: &mut impl BufRead,
-) -> Result<[u8; N], std::io::Error> {
-    let mut bytes = [0u8; N];
-    reader.read_exact(&mut bytes)?;
-    Ok(bytes)
-}
-
-pub(crate) fn read_i32(reader: &mut impl BufRead) -> Result<i32, std::io::Error> {
-    Ok(i32::from_le_bytes(read_bytes::<4>(reader)?))
-}
-
-pub(crate) fn read_u32(reader: &mut impl BufRead) -> Result<u32, std::io::Error> {
-    Ok(u32::from_le_bytes(read_bytes::<4>(reader)?))
-}
-
-pub(crate) fn read_f32(reader: &mut impl BufRead) -> Result<f32, std::io::Error> {
-    Ok(f32::from_le_bytes(read_bytes::<4>(reader)?))
-}
-
-pub(crate) fn read_bytes_with_len(
-    reader: &mut impl BufRead,
-    len: usize,
-) -> Result<Vec<u8>, std::io::Error> {
-    let mut bytes = vec![0u8; len];
-    reader.read_exact(&mut bytes)?;
-    Ok(bytes)
-}
-
 /// The hyperparameters of the model.
 #[derive(Debug, Clone)]
-pub struct FixedHyperparameters {
+pub struct LlamaHyperparameters {
     pub n_vocab: usize,
     pub n_embd: usize,
     pub n_mult: usize,
@@ -90,7 +77,7 @@ pub enum LoadError {
 
     #[error("unsupported tensor dtype/f16_: {0}")]
     UnsupportedElementtype(i32),
-    
+
     /// sanity check failed
     #[error("invariant broken: {0}")]
     InvariantBroken(String),
@@ -107,11 +94,11 @@ pub struct TensorInfo {
 
 #[allow(unused_variables)]
 pub trait LoadHandler<T> {
-    fn cb_container_type(&mut self, model_type: ModelContainerType) -> ControlFlow<T> {
+    fn cb_container_type(&mut self, model_type: ContainerType) -> ControlFlow<T> {
         ControlFlow::Continue(())
     }
 
-    fn cb_hyper_parameters(&mut self, hparams: FixedHyperparameters) -> ControlFlow<T> {
+    fn cb_hyper_parameters(&mut self, hparams: LlamaHyperparameters) -> ControlFlow<T> {
         ControlFlow::Continue(())
     }
 
@@ -134,30 +121,30 @@ pub fn load_model_from_reader<T>(
     handler: &mut impl LoadHandler<T>,
 ) -> Result<(), LoadError<T>> {
     // Verify magic
-    let container_type: ModelContainerType = match read_u32(&mut reader)? {
-        ggml::FILE_MAGIC_GGMF => ModelContainerType::GGMF,
-        ggml::FILE_MAGIC_GGJT => ModelContainerType::GGJT,
-        ggml::FILE_MAGIC_UNVERSIONED => ModelContainerType::Unversioned,
+    let container_type: ContainerType = match read_u32(&mut reader)? {
+        ggml::FILE_MAGIC_GGMF => ContainerType::GGMF,
+        ggml::FILE_MAGIC_GGJT => ContainerType::GGJT,
+        ggml::FILE_MAGIC_UNVERSIONED => ContainerType::GGML,
         magic => return Err(LoadError::InvalidMagic(magic)),
     };
     retchk(handler.cb_container_type(container_type))?;
 
     // Load format version
     match container_type {
-        ModelContainerType::GGMF | ModelContainerType::GGJT => {
+        ContainerType::GGMF | ContainerType::GGJT => {
             let _version: u32 = match read_u32(&mut reader)? {
                 ggml::FORMAT_VERSION => ggml::FORMAT_VERSION,
                 version => return Err(LoadError::InvalidFormatVersion(version)),
             };
         }
-        ModelContainerType::Unversioned => {}
+        ContainerType::GGML => {}
     }
 
     // Load hyper params
     //
     // NOTE: Field order matters! Data is laid out in the file exactly
     // in this order.
-    let hparams = FixedHyperparameters {
+    let hparams = LlamaHyperparameters {
         n_vocab: read_i32(&mut reader)?.try_into()?,
         n_embd: read_i32(&mut reader)?.try_into()?,
         n_mult: read_i32(&mut reader)?.try_into()?,
@@ -174,8 +161,8 @@
         let len = read_u32(&mut reader)?.try_into()?;
         let token = read_bytes_with_len(&mut reader, len)?;
         let token_score = match container_type {
-            ModelContainerType::GGMF | ModelContainerType::GGJT => read_f32(&mut reader)?,
-            ModelContainerType::Unversioned => {
+            ContainerType::GGMF | ContainerType::GGJT => read_f32(&mut reader)?,
+            ContainerType::GGML => {
                 // Legacy model, set empty score
                 0.
             }
@@ -185,12 +172,12 @@
-    
+
     // Load tensor data
     match container_type {
-        ModelContainerType::GGMF | ModelContainerType::Unversioned => {
+        ContainerType::GGMF | ContainerType::GGML => {
             let _file_offset = reader.stream_position()?;
             drop(reader);
             todo!()
         }
-        ModelContainerType::GGJT => load_weights_ggjt(&mut reader, handler),
+        ContainerType::GGJT => load_weights_ggjt(&mut reader, handler),
     }
 }
 
@@ -238,23 +225,28 @@ fn load_weights_ggjt(
         }
 
         let tensor_info = TensorInfo {
-            name, dims, n_dims, n_elements, ftype,
+            name,
+            dims,
+            n_dims,
+            n_elements,
+            ftype,
         };
 
         // load tensor weights
         let offset_curr = reader.stream_position()?;
         let offset_aligned: u64 = (offset_curr + 31) & !31;
         reader.seek(SeekFrom::Start(offset_aligned))?;
-        
+
         let type_size = ggml::type_size(ftype);
         let buf = retchk(handler.tensor_buffer(tensor_info))?;
         let buf_len = buf.len();
         if !(buf_len == type_size * n_elements) {
-            return Err(LoadError::InvariantBroken(format!("{buf_len} == {type_size} * {n_elements}")));
+            return Err(LoadError::InvariantBroken(format!(
+                "{buf_len} == {type_size} * {n_elements}"
+            )));
         }
         reader.read_exact(buf)?;
     }
 
     Ok(())
 }
-
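The callback API in `llama-loader/src/lib.rs` above is built around `std::ops::ControlFlow`: each `LoadHandler` hook returns `Continue` to keep loading or `Break` to abort, and `retchk` converts a `Break` into a `LoadError`. A reduced, self-contained sketch of that pattern (the names `Handler`, `StopAfter`, and `drive` are hypothetical, not items from the crate):

```rust
use std::ops::ControlFlow;

/// Hypothetical, reduced version of the loader's callback trait.
trait Handler {
    /// Return `Break` to make the driver stop early.
    fn on_item(&mut self, item: u32) -> ControlFlow<()>;
}

/// A handler that aborts after seeing three items.
struct StopAfter(u32);

impl Handler for StopAfter {
    fn on_item(&mut self, item: u32) -> ControlFlow<()> {
        self.0 += 1;
        println!("saw item {item}");
        if self.0 >= 3 {
            ControlFlow::Break(())
        } else {
            ControlFlow::Continue(())
        }
    }
}

/// Turns a `Break` from the handler into an error, analogous to `retchk`.
fn drive(handler: &mut impl Handler) -> Result<(), &'static str> {
    for i in 0..10 {
        if let ControlFlow::Break(()) = handler.on_item(i) {
            return Err("interrupted by handler");
        }
    }
    Ok(())
}

fn main() {
    assert!(drive(&mut StopAfter(0)).is_err());
}
```

This is what lets a caller veto or skip work at each step (e.g. decline to provide a tensor buffer) without the loader defining its own status enum.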
diff --git a/llama-loader/src/util.rs b/llama-loader/src/util.rs
new file mode 100644
index 00000000..06e5312f
--- /dev/null
+++ b/llama-loader/src/util.rs
@@ -0,0 +1,33 @@
+use std::io::BufRead;
+
+pub fn read_bytes<const N: usize>(reader: &mut impl BufRead) -> Result<[u8; N], std::io::Error> {
+    let mut bytes = [0u8; N];
+    reader.read_exact(&mut bytes)?;
+    Ok(bytes)
+}
+
+pub fn read_i32(reader: &mut impl BufRead) -> Result<i32, std::io::Error> {
+    Ok(i32::from_le_bytes(read_bytes::<4>(reader)?))
+}
+
+pub fn read_u32(reader: &mut impl BufRead) -> Result<u32, std::io::Error> {
+    Ok(u32::from_le_bytes(read_bytes::<4>(reader)?))
+}
+
+pub fn read_f32(reader: &mut impl BufRead) -> Result<f32, std::io::Error> {
+    Ok(f32::from_le_bytes(read_bytes::<4>(reader)?))
+}
+
+pub fn read_bytes_with_len(
+    reader: &mut impl BufRead,
+    len: usize,
+) -> Result<Vec<u8>, std::io::Error> {
+    let mut bytes = vec![0u8; len];
+    reader.read_exact(&mut bytes)?;
+    Ok(bytes)
+}
+
+// NOTE: Implementation from #![feature(buf_read_has_data_left)]
+pub fn has_data_left(reader: &mut impl BufRead) -> Result<bool, std::io::Error> {
+    reader.fill_buf().map(|b| !b.is_empty())
+}
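The helpers above all decode little-endian values from any `BufRead`. A usage sketch against an in-memory `Cursor` (which implements `BufRead`), assuming the `llama-loader` crate is available as a dependency:

```rust
use std::io::Cursor;

use llama_loader::util::*; // assumption: llama-loader is in Cargo.toml

fn main() -> std::io::Result<()> {
    // Simulate a fragment of a model file: a length-prefixed token plus a score.
    let mut data = Vec::new();
    data.extend_from_slice(&3u32.to_le_bytes()); // token length
    data.extend_from_slice(b"abc"); // token bytes
    data.extend_from_slice(&0.5f32.to_le_bytes()); // token score

    let mut reader = Cursor::new(data);
    let len = read_u32(&mut reader)? as usize;
    let token = read_bytes_with_len(&mut reader, len)?;
    let score = read_f32(&mut reader)?;

    assert_eq!(token, b"abc".to_vec());
    assert_eq!(score, 0.5);
    assert!(!has_data_left(&mut reader)?);
    Ok(())
}
```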
diff --git a/llama-rs/Cargo.toml b/llama-rs/Cargo.toml
index 302a1389..d6d42087 100644
--- a/llama-rs/Cargo.toml
+++ b/llama-rs/Cargo.toml
@@ -8,6 +8,7 @@
 rust-version = "1.65"
 
 [dependencies]
 ggml = { path = "../ggml" }
+llama-loader = { path = "../llama-loader" }
 bytemuck = "1.13.1"
 partial_sort = "0.2.0"
diff --git a/llama-rs/src/convert.rs b/llama-rs/src/convert.rs
index 450150b5..3a57b168 100644
--- a/llama-rs/src/convert.rs
+++ b/llama-rs/src/convert.rs
@@ -16,7 +16,8 @@ use std::{
     vec,
 };
 
-use crate::{util, Hyperparameters, Vocabulary, loader2::encode_element_type};
+use crate::{util, Hyperparameters, Vocabulary};
+use llama_loader::encode_element_type;
 
 /// Converts a `pth` file to a `ggml` file.
 pub fn convert_pth_to_ggml(model_directory: &Path, element_type: ggml::Type) {
diff --git a/llama-rs/src/lib.rs b/llama-rs/src/lib.rs
index cefa88f5..593cdd11 100644
--- a/llama-rs/src/lib.rs
+++ b/llama-rs/src/lib.rs
@@ -4,14 +4,13 @@
 #[cfg(feature = "convert")]
 pub mod convert;
 mod loader;
-pub mod loader2;
 mod util;
 
 use core::slice;
 use std::{
     collections::HashMap,
     fmt::Display,
-    io::{BufRead, Read, Seek, SeekFrom},
+    io::Seek,
     path::{Path, PathBuf},
     str::FromStr,
     time,
@@ -25,7 +24,7 @@ use thiserror::Error;
 #[cfg(feature = "mmap")]
 use memmap2::Mmap;
 
-use crate::loader2::decode_element_type;
+use llama_loader::{decode_element_type, ContainerType};
 
 /// dummy struct
 #[cfg(not(feature = "mmap"))]
@@ -76,18 +75,6 @@ struct Layer {
     w3: ggml::Tensor,
 }
 
-/// file type containing the model
-#[derive(Debug, PartialEq, Clone, Copy)]
-#[allow(clippy::upper_case_acronyms)]
-pub enum ModelContainerType {
-    /// older than `GGJT`
-    GGMF,
-    /// mmap-able format
-    GGJT,
-    /// oldest ggml tensor file format
-    Unversioned,
-}
-
 /// The weights for the LLaMA model. All the mutable state is split into a
 /// separate struct `InferenceSession`.
 pub struct Model {
@@ -104,7 +91,7 @@
 
     mmap: Option<Mmap>,
 
-    _version: ModelContainerType,
+    _version: ContainerType,
 
     // Must be kept alive for the model
     _context: ggml::Context,
@@ -623,10 +610,10 @@ impl Model {
         let mut reader = BufReader::new(&file);
 
         // Verify magic
-        let model_type: ModelContainerType = match read_u32(&mut reader)? {
-            ggml::FILE_MAGIC_GGMF => ModelContainerType::GGMF,
-            ggml::FILE_MAGIC_GGJT => ModelContainerType::GGJT,
-            ggml::FILE_MAGIC_UNVERSIONED => ModelContainerType::Unversioned,
+        let model_type: ContainerType = match read_u32(&mut reader)? {
+            ggml::FILE_MAGIC_GGMF => ContainerType::GGMF,
+            ggml::FILE_MAGIC_GGJT => ContainerType::GGJT,
+            ggml::FILE_MAGIC_UNVERSIONED => ContainerType::GGML,
             _ => {
                 return Err(LoadError::InvalidMagic {
                     path: main_path.to_owned(),
@@ -636,13 +623,13 @@
 
         // Load format version
         match model_type {
-            ModelContainerType::GGMF | ModelContainerType::GGJT => {
+            ContainerType::GGMF | ContainerType::GGJT => {
                 let _version: u32 = match read_u32(&mut reader)? {
                     ggml::FORMAT_VERSION => ggml::FORMAT_VERSION,
                     version => return Err(LoadError::InvalidFormatVersion { value: version }),
                 };
             }
-            ModelContainerType::Unversioned => {}
+            ContainerType::GGML => {}
         }
 
         // =================
@@ -682,23 +669,21 @@ impl Model {
         for i in 0..hparams.n_vocab {
             let len = match model_type {
                 // `read_i32` maybe a typo
-                ModelContainerType::GGMF | ModelContainerType::Unversioned => {
-                    read_i32(&mut reader)? as usize
-                }
-                ModelContainerType::GGJT => read_u32(&mut reader)? as usize,
+                ContainerType::GGMF | ContainerType::GGML => read_i32(&mut reader)? as usize,
+                ContainerType::GGJT => read_u32(&mut reader)? as usize,
             };
-            let token = read_bytes_with_len(&mut reader, len as usize)?;
+            let token = read_bytes_with_len(&mut reader, len)?;
             max_token_length = max_token_length.max(token.len());
             id_to_token.push(token.clone());
             token_to_id.insert(token, TokenId::try_from(i)?);
 
             // Token score, currently unused
             match model_type {
-                ModelContainerType::GGMF | ModelContainerType::GGJT => {
+                ContainerType::GGMF | ContainerType::GGJT => {
                     let score = read_f32(&mut reader)?;
                     id_to_token_score.push(score);
                 }
-                ModelContainerType::Unversioned => {
+                ContainerType::GGML => {
                     // Legacy model, set empty score
                     id_to_token_score.push(0.);
                 }
@@ -825,7 +810,7 @@ impl Model {
         };
 
         match model_type {
-            ModelContainerType::GGMF | ModelContainerType::Unversioned => {
+            ContainerType::GGMF | ContainerType::GGML => {
                 let file_offset = reader.stream_position()?;
                 drop(reader);
                 load_weights_ggmf_or_unversioned(
@@ -835,7 +820,7 @@ impl Model {
                     &model,
                 )?
             }
-            ModelContainerType::GGJT => {
+            ContainerType::GGJT => {
                 let mmap = unsafe { Mmap::map(&file)? };
                 let ptr = mmap.as_ptr();
                 model.mmap = Some(mmap);
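For GGJT the weights are no longer copied into the ggml context: the file is memory-mapped, the `Mmap` is kept alive in `model.mmap`, and tensor data pointers are computed from the mapping's base address. A minimal standalone sketch of that pattern with `memmap2` (the file name is illustrative):

```rust
use memmap2::Mmap;
use std::fs::File;

fn main() -> std::io::Result<()> {
    let file = File::open("ggml-model-q4_0.bin")?; // illustrative path
    // Safety: the mapping is only valid as long as no other process
    // truncates or rewrites the file underneath us.
    let mmap = unsafe { Mmap::map(&file)? };
    let base = mmap.as_ptr();
    // The loader keeps `mmap` alive inside the Model and points tensors
    // at offsets from `base` within the mapping instead of copying weights.
    println!("mapped {} bytes at {:p}", mmap.len(), base);
    Ok(())
}
```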
diff --git a/llama-rs/src/loader.rs b/llama-rs/src/loader.rs
index dc2c29bd..dc8d4c18 100644
--- a/llama-rs/src/loader.rs
+++ b/llama-rs/src/loader.rs
@@ -1,4 +1,11 @@
-use crate::{loader2::decode_element_type, *};
+use std::{
+    io::{BufRead, Read, Seek, SeekFrom},
+    path::Path,
+};
+
+use crate::ElementType;
+use crate::{util, LoadError, LoadProgress, Model};
+use llama_loader::decode_element_type;
 
 pub(crate) fn read_bytes<const N: usize>(reader: &mut impl BufRead) -> Result<[u8; N], LoadError> {
     let mut bytes = [0u8; N];
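One detail worth noting from `load_weights_ggjt` above: before reading each tensor, the reader seeks to a 32-byte boundary via `(offset_curr + 31) & !31`, the standard align-up idiom for power-of-two alignments (and what keeps the mmap'd tensor data suitably aligned). The idiom in isolation:

```rust
/// Round `offset` up to the next multiple of 32.
/// For any power-of-two alignment `a`: (x + (a - 1)) & !(a - 1).
fn align32(offset: u64) -> u64 {
    (offset + 31) & !31
}

fn main() {
    assert_eq!(align32(0), 0);
    assert_eq!(align32(1), 32);
    assert_eq!(align32(32), 32);
    assert_eq!(align32(100), 128);
}
```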