From bdbea689c5ec17916dc86092d44ad8d55cf1b244 Mon Sep 17 00:00:00 2001 From: Locria Cyber <74560659+iacore@users.noreply.github.com> Date: Thu, 6 Apr 2023 18:22:44 +0000 Subject: [PATCH 01/42] Add loader stub for GGJT --- ggml/src/lib.rs | 6 +- llama-rs/src/lib.rs | 322 +++++------------------------------------ llama-rs/src/loader.rs | 293 +++++++++++++++++++++++++++++++++++++ 3 files changed, 333 insertions(+), 288 deletions(-) create mode 100644 llama-rs/src/loader.rs diff --git a/ggml/src/lib.rs b/ggml/src/lib.rs index 76a7e4ab..f8331bb5 100644 --- a/ggml/src/lib.rs +++ b/ggml/src/lib.rs @@ -14,8 +14,10 @@ use std::{ sync::{Arc, Weak}, }; -/// Magic constant for `ggml` files (versioned). -pub const FILE_MAGIC: u32 = 0x67676d66; +/// Magic constant for `ggml` files (versioned, ggmf). +pub const FILE_MAGIC_GGMF: u32 = 0x67676d66; +/// Magic constant for `ggml` files (versioned, ggjt). +pub const FILE_MAGIC_GGJT: u32 = 0x67676a74; /// Magic constant for `ggml` files (unversioned). pub const FILE_MAGIC_UNVERSIONED: u32 = 0x67676d6c; diff --git a/llama-rs/src/lib.rs b/llama-rs/src/lib.rs index 14553379..ed68cf0e 100644 --- a/llama-rs/src/lib.rs +++ b/llama-rs/src/lib.rs @@ -1,6 +1,8 @@ #![deny(missing_docs)] //! LLaMA-rs is a Rust port of the llama.cpp project. This allows running inference for Facebook's LLaMA model on a CPU with good performance using full precision, f16 or 4-bit quantized versions of the model. +mod loader; + use core::slice; use std::{ collections::HashMap, @@ -580,6 +582,7 @@ impl Model { n_context_tokens: usize, load_progress_callback: impl Fn(LoadProgress), ) -> Result<(Model, Vocabulary), LoadError> { + use loader::*; use std::fs::File; use std::io::BufReader; @@ -593,46 +596,11 @@ impl Model { })?, ); - fn read_bytes(reader: &mut impl BufRead) -> Result<[u8; N], LoadError> { - let mut bytes = [0u8; N]; - reader - .read_exact(&mut bytes) - .map_err(|e| LoadError::ReadExactFailed { - source: e, - bytes: N, - })?; - Ok(bytes) - } - - fn read_i32(reader: &mut impl BufRead) -> Result { - Ok(i32::from_le_bytes(read_bytes::<4>(reader)?)) - } - - fn read_u32(reader: &mut impl BufRead) -> Result { - Ok(u32::from_le_bytes(read_bytes::<4>(reader)?)) - } - - fn read_f32(reader: &mut impl BufRead) -> Result { - Ok(f32::from_le_bytes(read_bytes::<4>(reader)?)) - } - - /// Helper function. Reads a string from the buffer and returns it. - fn read_string(reader: &mut BufReader, len: usize) -> Result { - let mut buf = vec![0; len]; - reader - .read_exact(&mut buf) - .map_err(|e| LoadError::ReadExactFailed { - source: e, - bytes: buf.len(), - })?; - let s = String::from_utf8(buf)?; - Ok(s) - } - // Verify magic - let is_legacy_model: bool = match read_u32(&mut reader)? { - ggml::FILE_MAGIC => false, - ggml::FILE_MAGIC_UNVERSIONED => true, + let model_type: ModelType = match read_u32(&mut reader)? { + ggml::FILE_MAGIC_GGMF => ModelType::GGMF, + ggml::FILE_MAGIC_GGJT => ModelType::GGJT, + ggml::FILE_MAGIC_UNVERSIONED => ModelType::Unversioned, _ => { return Err(LoadError::InvalidMagic { path: main_path.to_owned(), @@ -641,12 +609,14 @@ impl Model { }; // Load format version - if !is_legacy_model { - #[allow(unused_variables)] - let version: u32 = match read_u32(&mut reader)? { - ggml::FORMAT_VERSION => ggml::FORMAT_VERSION, - version => return Err(LoadError::InvalidFormatVersion { value: version }), - }; + match model_type { + ModelType::GGMF | ModelType::GGJT => { + let _version: u32 = match read_u32(&mut reader)? 
{ + ggml::FORMAT_VERSION => ggml::FORMAT_VERSION, + version => return Err(LoadError::InvalidFormatVersion { value: version }), + }; + } + ModelType::Unversioned => {} } // ================= @@ -681,8 +651,12 @@ impl Model { let mut max_token_length = 0; for i in 0..hparams.n_vocab { - let len = read_i32(&mut reader)?; - if let Ok(word) = read_string(&mut reader, len as usize) { + let len = match model_type { + // `read_i32` maybe a typo + ModelType::GGMF | ModelType::Unversioned => read_i32(&mut reader)? as usize, + ModelType::GGJT => read_u32(&mut reader)? as usize, + }; + if let Ok(word) = read_string(&mut reader, len) { max_token_length = max_token_length.max(word.len()); id_to_token.push(word.clone()); token_to_id.insert(word, TokenId::try_from(i)?); @@ -692,13 +666,16 @@ impl Model { } // Token score, currently unused - if !is_legacy_model { - if let Ok(score) = read_f32(&mut reader) { - id_to_token_score.push(score); + match model_type { + ModelType::GGMF | ModelType::GGJT => { + if let Ok(score) = read_f32(&mut reader) { + id_to_token_score.push(score); + } + } + ModelType::Unversioned => { + // Legacy model, set empty score + id_to_token_score.push(0.); } - } else { - // Legacy model, set empty score - id_to_token_score.push(0.); } } @@ -825,240 +802,13 @@ impl Model { } }; - // Close the file, but keep its offset. That way we know how to skip the - // metadata when loading the parts. - let file_offset = reader.stream_position()?; - drop(reader); - - let paths = util::find_all_model_files(main_path)?; - let n_parts = paths.len(); - - for (i, part_path) in paths.into_iter().enumerate() { - let part_id = i; - - load_progress_callback(LoadProgress::PartLoading { - file: &part_path, - current_part: i, - total_parts: n_parts, - }); - - let mut part_reader = BufReader::new(File::open(&part_path)?); - - // Skip metadata - part_reader.seek(SeekFrom::Start(file_offset))?; - - let mut total_size = 0; - let mut n_tensors = 0; - - // Load weights - loop { - // NOTE: Implementation from #![feature(buf_read_has_data_left)] - let is_eof = part_reader.fill_buf().map(|b| b.is_empty())?; - - if is_eof { - break; - } - - let n_dims = usize::try_from(read_i32(&mut part_reader)?)?; - let length = read_i32(&mut part_reader)?; - let ftype = read_u32(&mut part_reader)?; - - let mut nelements = 1; - let mut ne = [1i64, 1i64]; - - #[allow(clippy::needless_range_loop)] - for i in 0..n_dims { - ne[i] = read_i32(&mut part_reader)? 
as i64; - nelements *= usize::try_from(ne[i])?; - } - - let tensor_name = read_string(&mut part_reader, length as usize)?; - - let Some(tensor) = model.tensors.get(&tensor_name) - else { - return Err(LoadError::UnknownTensor { tensor_name, path: part_path }); - }; - - // split_type = 0: split by columns - // split_type = 1: split by rows - // - // split_type = 0: - // regex: - // - tok_embeddings.* - // - layers.*.attention.wo.weight - // - layers.*.feed_forward.w2.weight - - // split_type = 1: - // regex: - // - output.* - // - layers.*.attention.wq.weight - // - layers.*.attention.wk.weight - // - layers.*.attention.wv.weight - // - layers.*.feed_forward.w1.weight - // - layers.*.feed_forward.w3.weight - #[allow(clippy::if_same_then_else)] - let split_type = if tensor_name.contains("tok_embeddings") { - 0 - } else if tensor_name.contains("layers") { - if tensor_name.contains("attention.wo.weight") { - 0 - } else if tensor_name.contains("feed_forward.w2.weight") { - 0 - } else { - 1 - } - } else if tensor_name.contains("output") { - 1 - } else { - 0 - }; - - if n_dims == 1 { - if tensor.nelements() != nelements { - return Err(LoadError::TensorWrongSize { - tensor_name, - path: part_path, - }); - } - } else if tensor.nelements() / n_parts != nelements { - return Err(LoadError::TensorWrongSize { - tensor_name, - path: part_path, - }); - } - - if n_dims == 1 { - if tensor.get_ne()[0] != ne[0] || tensor.get_ne()[1] != ne[1] { - return Err(LoadError::TensorWrongSize { - tensor_name, - path: part_path, - }); - } - } else if split_type == 0 { - if tensor.get_ne()[0] / i64::try_from(n_parts)? != ne[0] - || tensor.get_ne()[1] != ne[1] - { - return Err(LoadError::TensorWrongSize { - tensor_name, - path: part_path, - }); - } - } else if tensor.get_ne()[0] != ne[0] - || tensor.get_ne()[1] / i64::try_from(n_parts)? != ne[1] - { - return Err(LoadError::TensorWrongSize { - tensor_name, - path: part_path, - }); - } - - let bpe = match ftype { - 0 => ggml::type_size(ggml::Type::F32), - 1 => ggml::type_size(ggml::Type::F16), - 2 => { - assert_eq!(ne[0] % 64, 0); - ggml::type_size(ggml::Type::Q4_0) - } - 3 => { - assert_eq!(ne[0] % 64, 0); - ggml::type_size(ggml::Type::Q4_1) - } - _ => { - return Err(LoadError::InvalidFtype { - tensor_name, - ftype, - path: part_path, - }) - } - }; - - if n_dims == 1 || n_parts == 1 { - if (nelements * bpe) / ggml::blck_size(tensor.get_type()) != tensor.nbytes() { - return Err(LoadError::TensorWrongSize { - tensor_name, - path: part_path, - }); - } - - if part_id == 0 { - // SAFETY: yolo, same as original code - let slice = unsafe { - let data = tensor.data(); - std::slice::from_raw_parts_mut(data as *mut u8, tensor.nbytes()) - }; - part_reader.read_exact(slice)?; - } else { - part_reader.seek(SeekFrom::Current(tensor.nbytes() as i64))?; - } - - total_size += tensor.nbytes(); - } else { - if (nelements * bpe) / ggml::blck_size(tensor.get_type()) - != tensor.nbytes() / n_parts - { - return Err(LoadError::TensorWrongSize { - tensor_name, - path: part_path, - }); - } - - if split_type == 0 { - let np0 = ne[0]; - let row_size = (usize::try_from(tensor.get_ne()[0])? 
- / ggml::blck_size(tensor.get_type())) - * ggml::type_size(tensor.get_type()); - - assert_eq!(row_size, tensor.get_nb()[1]); - - for i1 in 0..ne[1] { - let offset_row = i1 as usize * row_size; - let offset = offset_row - + ((part_id * np0 as usize) / ggml::blck_size(tensor.get_type())) - * ggml::type_size(tensor.get_type()); - // SAFETY: yolo, same as original code - unsafe { - let ptr = tensor.data().add(offset); - let slice = std::slice::from_raw_parts_mut( - ptr as *mut u8, - row_size / n_parts, - ); - part_reader.read_exact(slice)?; - } - } - } else { - let np1 = ne[1]; - let row_size = (usize::try_from(tensor.get_ne()[0])? - / ggml::blck_size(tensor.get_type())) - * ggml::type_size(tensor.get_type()); - - for i1 in 0..ne[1] { - let offset_row = (i1 as usize + part_id * np1 as usize) * row_size; - // SAFETY: yolo, same as original code - unsafe { - let ptr = tensor.data().add(offset_row); - let slice = - std::slice::from_raw_parts_mut(ptr as *mut u8, row_size); - part_reader.read_exact(slice)?; - } - } - } - - total_size += tensor.nbytes() / n_parts; - } - - n_tensors += 1; - load_progress_callback(LoadProgress::PartTensorLoaded { - file: &part_path, - current_tensor: n_tensors.try_into()?, - tensor_count: model.tensors.len(), - }); + match model_type { + ModelType::GGMF | ModelType::Unversioned => { + load_weights_ggmf_or_unversioned(reader, main_path, load_progress_callback, &model)? + } + ModelType::GGJT => { + load_weights_ggjt(reader, main_path, load_progress_callback, &model)? } - - load_progress_callback(LoadProgress::PartLoaded { - file: &part_path, - byte_size: total_size, - tensor_count: n_tensors.try_into()?, - }); } Ok((model, vocab)) diff --git a/llama-rs/src/loader.rs b/llama-rs/src/loader.rs new file mode 100644 index 00000000..4269538d --- /dev/null +++ b/llama-rs/src/loader.rs @@ -0,0 +1,293 @@ +use std::{fs::File, io::BufReader}; + +use crate::*; + +pub(crate) fn read_bytes(reader: &mut impl BufRead) -> Result<[u8; N], LoadError> { + let mut bytes = [0u8; N]; + reader + .read_exact(&mut bytes) + .map_err(|e| LoadError::ReadExactFailed { + source: e, + bytes: N, + })?; + Ok(bytes) +} + +pub(crate) fn read_i32(reader: &mut impl BufRead) -> Result { + Ok(i32::from_le_bytes(read_bytes::<4>(reader)?)) +} + +pub(crate) fn read_u32(reader: &mut impl BufRead) -> Result { + Ok(u32::from_le_bytes(read_bytes::<4>(reader)?)) +} + +pub(crate) fn read_f32(reader: &mut impl BufRead) -> Result { + Ok(f32::from_le_bytes(read_bytes::<4>(reader)?)) +} + +/// Helper function. Reads a string from the buffer and returns it. 
+pub(crate) fn read_string(reader: &mut BufReader, len: usize) -> Result { + let mut buf = vec![0; len]; + reader + .read_exact(&mut buf) + .map_err(|e| LoadError::ReadExactFailed { + source: e, + bytes: buf.len(), + })?; + let s = String::from_utf8(buf)?; + Ok(s) +} + +#[derive(PartialEq)] +pub(crate) enum ModelType { + GGMF, + GGJT, + Unversioned, +} + +pub(crate) fn load_weights_ggmf_or_unversioned( + mut reader: std::io::BufReader, + main_path: &Path, + load_progress_callback: impl Fn(LoadProgress), + model: &Model, +) -> Result<(), LoadError> { + let file_offset = reader.stream_position()?; + drop(reader); + + let paths = util::find_all_model_files(main_path)?; + + let n_parts = paths.len(); + Ok(for (i, part_path) in paths.into_iter().enumerate() { + let part_id = i; + + load_progress_callback(LoadProgress::PartLoading { + file: &part_path, + current_part: i, + total_parts: n_parts, + }); + + let mut part_reader = BufReader::new(File::open(&part_path)?); + + // Skip metadata + part_reader.seek(SeekFrom::Start(file_offset))?; + + let mut total_size = 0; + let mut n_tensors = 0; + + // Load weights + loop { + // NOTE: Implementation from #![feature(buf_read_has_data_left)] + let is_eof = part_reader.fill_buf().map(|b| b.is_empty())?; + + if is_eof { + break; + } + + let n_dims = usize::try_from(read_i32(&mut part_reader)?)?; + let length = read_i32(&mut part_reader)?; + let ftype = read_u32(&mut part_reader)?; + + let mut nelements = 1; + let mut ne = [1i64, 1i64]; + + #[allow(clippy::needless_range_loop)] + for i in 0..n_dims { + ne[i] = read_i32(&mut part_reader)? as i64; + nelements *= usize::try_from(ne[i])?; + } + + let tensor_name = read_string(&mut part_reader, length as usize)?; + + let Some(tensor) = model.tensors.get(&tensor_name) + else { + return Err(LoadError::UnknownTensor { tensor_name, path: part_path }); + }; + + // split_type = 0: split by columns + // split_type = 1: split by rows + // + // split_type = 0: + // regex: + // - tok_embeddings.* + // - layers.*.attention.wo.weight + // - layers.*.feed_forward.w2.weight + + // split_type = 1: + // regex: + // - output.* + // - layers.*.attention.wq.weight + // - layers.*.attention.wk.weight + // - layers.*.attention.wv.weight + // - layers.*.feed_forward.w1.weight + // - layers.*.feed_forward.w3.weight + #[allow(clippy::if_same_then_else)] + let split_type = if tensor_name.contains("tok_embeddings") { + 0 + } else if tensor_name.contains("layers") { + if tensor_name.contains("attention.wo.weight") { + 0 + } else if tensor_name.contains("feed_forward.w2.weight") { + 0 + } else { + 1 + } + } else if tensor_name.contains("output") { + 1 + } else { + 0 + }; + + if n_dims == 1 { + if tensor.nelements() != nelements { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: part_path, + }); + } + } else if tensor.nelements() / n_parts != nelements { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: part_path, + }); + } + + if n_dims == 1 { + if tensor.get_ne()[0] != ne[0] || tensor.get_ne()[1] != ne[1] { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: part_path, + }); + } + } else if split_type == 0 { + if tensor.get_ne()[0] / i64::try_from(n_parts)? != ne[0] + || tensor.get_ne()[1] != ne[1] + { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: part_path, + }); + } + } else if tensor.get_ne()[0] != ne[0] + || tensor.get_ne()[1] / i64::try_from(n_parts)? 
!= ne[1] + { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: part_path, + }); + } + + let bpe = match ftype { + 0 => ggml::type_size(ggml::Type::F32), + 1 => ggml::type_size(ggml::Type::F16), + 2 => { + assert_eq!(ne[0] % 64, 0); + ggml::type_size(ggml::Type::Q4_0) + } + 3 => { + assert_eq!(ne[0] % 64, 0); + ggml::type_size(ggml::Type::Q4_1) + } + _ => { + return Err(LoadError::InvalidFtype { + tensor_name, + ftype, + path: part_path, + }) + } + }; + + if n_dims == 1 || n_parts == 1 { + if (nelements * bpe) / ggml::blck_size(tensor.get_type()) != tensor.nbytes() { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: part_path, + }); + } + + if part_id == 0 { + // SAFETY: yolo, same as original code + let slice = unsafe { + let data = tensor.data(); + std::slice::from_raw_parts_mut(data as *mut u8, tensor.nbytes()) + }; + part_reader.read_exact(slice)?; + } else { + part_reader.seek(SeekFrom::Current(tensor.nbytes() as i64))?; + } + + total_size += tensor.nbytes(); + } else { + if (nelements * bpe) / ggml::blck_size(tensor.get_type()) + != tensor.nbytes() / n_parts + { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: part_path, + }); + } + + if split_type == 0 { + let np0 = ne[0]; + let row_size = (usize::try_from(tensor.get_ne()[0])? + / ggml::blck_size(tensor.get_type())) + * ggml::type_size(tensor.get_type()); + + assert_eq!(row_size, tensor.get_nb()[1]); + + for i1 in 0..ne[1] { + let offset_row = i1 as usize * row_size; + let offset = offset_row + + ((part_id * np0 as usize) / ggml::blck_size(tensor.get_type())) + * ggml::type_size(tensor.get_type()); + // SAFETY: yolo, same as original code + unsafe { + let ptr = tensor.data().add(offset); + let slice = + std::slice::from_raw_parts_mut(ptr as *mut u8, row_size / n_parts); + part_reader.read_exact(slice)?; + } + } + } else { + let np1 = ne[1]; + let row_size = (usize::try_from(tensor.get_ne()[0])? 
+ / ggml::blck_size(tensor.get_type())) + * ggml::type_size(tensor.get_type()); + + for i1 in 0..ne[1] { + let offset_row = (i1 as usize + part_id * np1 as usize) * row_size; + // SAFETY: yolo, same as original code + unsafe { + let ptr = tensor.data().add(offset_row); + let slice = std::slice::from_raw_parts_mut(ptr as *mut u8, row_size); + part_reader.read_exact(slice)?; + } + } + } + + total_size += tensor.nbytes() / n_parts; + } + + n_tensors += 1; + load_progress_callback(LoadProgress::PartTensorLoaded { + file: &part_path, + current_tensor: n_tensors.try_into()?, + tensor_count: model.tensors.len(), + }); + } + + load_progress_callback(LoadProgress::PartLoaded { + file: &part_path, + byte_size: total_size, + tensor_count: n_tensors.try_into()?, + }); + }) +} + +pub(crate) fn load_weights_ggjt( + mut reader: std::io::BufReader, + main_path: &Path, + load_progress_callback: impl Fn(LoadProgress), + model: &Model, +) -> Result<(), LoadError> { + todo!("GGJT load weights"); +} From b0a666fcc340ab8185780ae6d293d6bccbe81447 Mon Sep 17 00:00:00 2001 From: Locria Cyber <74560659+iacore@users.noreply.github.com> Date: Thu, 6 Apr 2023 19:52:14 +0000 Subject: [PATCH 02/42] Add loading code for ggjt Now it can load the model, but it's not working --- Cargo.lock | 313 +++++++++++++++++++++++++-------------- ggml/src/lib.rs | 10 +- llama-rs/Cargo.toml | 3 +- llama-rs/src/lib.rs | 25 +++- llama-rs/src/loader.rs | 323 +++++++++++++++++++++++++---------------- 5 files changed, 435 insertions(+), 239 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9ac1053c..38bd0498 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,6 +17,46 @@ dependencies = [ "memchr", ] +[[package]] +name = "anstream" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "342258dd14006105c2b75ab1bd7543a03bdf0cfc94383303ac212a04939dff6f" +dependencies = [ + "anstyle", + "anstyle-parse", + "anstyle-wincon", + "concolor-override", + "concolor-query", + "is-terminal", + "utf8parse", +] + +[[package]] +name = "anstyle" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23ea9e81bd02e310c216d080f6223c179012256e5151c41db88d12c88a1684d2" + +[[package]] +name = "anstyle-parse" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7d1bb534e9efed14f3e5f44e7dd1a4f709384023a4165199a4241e18dff0116" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-wincon" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3127af6145b149f3287bb9a0d10ad9c5692dba8c53ad48285e5bec4063834fa" +dependencies = [ + "anstyle", + "windows-sys 0.45.0", +] + [[package]] name = "autocfg" version = "1.1.0" @@ -103,40 +143,45 @@ dependencies = [ [[package]] name = "clap" -version = "4.1.8" +version = "4.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3d7ae14b20b94cb02149ed21a86c423859cbe18dc7ed69845cace50e52b40a5" +checksum = "046ae530c528f252094e4a77886ee1374437744b2bff1497aa898bbddbbb29b3" dependencies = [ - "bitflags", + "clap_builder", "clap_derive", - "clap_lex", - "is-terminal", "once_cell", +] + +[[package]] +name = "clap_builder" +version = "4.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "223163f58c9a40c3b0a43e1c4b50a9ce09f007ea2cb1ec258a687945b4b7929f" +dependencies = [ + "anstream", + "anstyle", + "bitflags", + "clap_lex", "strsim", - "termcolor", ] [[package]] name = "clap_derive" -version = 
"4.1.8" +version = "4.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "44bec8e5c9d09e439c4335b1af0abaab56dcf3b94999a936e1bb47b9134288f0" +checksum = "3f9644cd56d6b87dbe899ef8b053e331c0637664e9e21a33dfcdc36093f5c5c4" dependencies = [ "heck", - "proc-macro-error", "proc-macro2", "quote", - "syn 1.0.109", + "syn 2.0.13", ] [[package]] name = "clap_lex" -version = "0.3.2" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "350b9cf31731f9957399229e9b2adc51eeabdfbe9d71d9a0552275fd12710d09" -dependencies = [ - "os_str_bytes", -] +checksum = "8a2dd5a6fe8c6e3502f568a6353e5273bbb15193ad9a89e457b9970798efbea1" [[package]] name = "clipboard-win" @@ -149,6 +194,21 @@ dependencies = [ "winapi", ] +[[package]] +name = "concolor-override" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a855d4a1978dc52fb0536a04d384c2c0c1aa273597f08b77c8c4d3b2eec6037f" + +[[package]] +name = "concolor-query" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88d11d52c3d7ca2e6d0040212be9e4dbbcd78b6447f535b6b561f449427944cf" +dependencies = [ + "windows-sys 0.45.0", +] + [[package]] name = "crossbeam-channel" version = "0.5.7" @@ -261,13 +321,13 @@ dependencies = [ [[package]] name = "errno" -version = "0.2.8" +version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f639046355ee4f37944e44f60642c6f3a7efa3cf6b78c78a0d989a8ce6c396a1" +checksum = "50d6a0976c999d473fe89ad888d5a284e55366d9dc9038b1ba2aa15128c4afa0" dependencies = [ "errno-dragonfly", "libc", - "winapi", + "windows-sys 0.45.0", ] [[package]] @@ -292,13 +352,13 @@ dependencies = [ [[package]] name = "fd-lock" -version = "3.0.10" +version = "3.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ef1a30ae415c3a691a4f41afddc2dbcd6d70baf338368d85ebc1e8ed92cedb9" +checksum = "39ae6b3d9530211fb3b12a95374b8b0823be812f53d09e18c5675c0146b09642" dependencies = [ "cfg-if", "rustix", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -310,9 +370,9 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.8" +version = "0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c05aeb6a22b8f62540c194aac980f2115af067bfe15a0734d7277a768d396b31" +checksum = "c85e1d9ab2eadba7e5040d4e09cbd6d072b76a557ad64e797c2cb9d4da21d7e4" dependencies = [ "cfg-if", "libc", @@ -378,24 +438,25 @@ checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" [[package]] name = "io-lifetimes" -version = "1.0.6" +version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cfa919a82ea574332e2de6e74b4c36e74d41982b335080fa59d4ef31be20fdf3" +checksum = "9c66c74d2ae7e79a5a8f7ac924adbe38ee42a859c6539ad869eb51f0b52dc220" dependencies = [ + "hermit-abi 0.3.1", "libc", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] name = "is-terminal" -version = "0.4.4" +version = "0.4.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "21b6b32576413a8e69b90e952e4a026476040d81017b80445deda5f2d3921857" +checksum = "adcf93614601c8129ddf72e2d5633df827ba6551541c6d8c59520a371475be1f" dependencies = [ "hermit-abi 0.3.1", "io-lifetimes", "rustix", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -436,9 +497,9 @@ checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" [[package]] name = "libc" -version = "0.2.140" +version = 
"0.2.141" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "99227334921fae1a979cf0bfdfcc6b3e5ce376ef57e16fb6fb3ea2ed6095f80c" +checksum = "3304a64d199bb964be99741b7a14d26972741915b3649639149b2479bb46f4b5" [[package]] name = "libloading" @@ -452,9 +513,9 @@ dependencies = [ [[package]] name = "linux-raw-sys" -version = "0.1.4" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f051f77a7c8e6957c0696eac88f26b0117e54f52d3fc682ab19397a8812846a4" +checksum = "d59d8c75012853d2e872fb56bc8a2e53718e2cafe1a4c823143141c6d90c322f" [[package]] name = "llama-cli" @@ -479,6 +540,7 @@ dependencies = [ "bincode", "bytemuck", "ggml", + "memmap2", "partial_sort", "protobuf", "rand", @@ -510,6 +572,15 @@ version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" +[[package]] +name = "memmap2" +version = "0.5.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83faa42c0a078c393f6b29d5db232d8be22776a891f8f56e5284faee4a20b327" +dependencies = [ + "libc", +] + [[package]] name = "memoffset" version = "0.8.0" @@ -572,12 +643,6 @@ version = "1.17.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b7e5500299e16ebb147ae15a00a942af264cf3688f47923b8fc2cd5858f23ad3" -[[package]] -name = "os_str_bytes" -version = "6.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b7820b9daea5457c9f21c69448905d723fbd21136ccf521748f23fd49e723ee" - [[package]] name = "partial_sort" version = "0.2.0" @@ -602,35 +667,11 @@ version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" -[[package]] -name = "proc-macro-error" -version = "1.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c" -dependencies = [ - "proc-macro-error-attr", - "proc-macro2", - "quote", - "syn 1.0.109", - "version_check", -] - -[[package]] -name = "proc-macro-error-attr" -version = "1.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869" -dependencies = [ - "proc-macro2", - "quote", - "version_check", -] - [[package]] name = "proc-macro2" -version = "1.0.52" +version = "1.0.56" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d0e1ae9e836cc3beddd63db0df682593d7e2d3d891ae8c9083d2113e1744224" +checksum = "2b63bdb0cd06f1f4dedf69b254734f9b45af66e4a031e42a7480257d9898b435" dependencies = [ "unicode-ident", ] @@ -643,9 +684,9 @@ checksum = "8e86d370532557ae7573551a1ec8235a0f8d6cb276c7c9e6aa490b511c447485" [[package]] name = "quote" -version = "1.0.25" +version = "1.0.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5308e8208729c3e1504a6cfad0d5daacc4614c9a2e65d1ea312a34b5cb00fe84" +checksum = "4424af4bf778aae2051a77b60283332f386554255d722233d09fbfc7e30da2fc" dependencies = [ "proc-macro2", ] @@ -734,9 +775,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.7.1" +version = "1.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48aaa5748ba571fb95cd2c85c09f629215d3a6ece942baa100950af03a34f733" +checksum = "8b1f693b24f6ac912f4893ef08244d70b6067480d2f1a46e950c9691e6749d1d" dependencies = [ "aho-corasick", "memchr", @@ 
-745,9 +786,9 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.6.28" +version = "0.6.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "456c603be3e8d448b072f410900c09faf164fbce2d480456f50eea6e25f9c848" +checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" [[package]] name = "rust_tokenizers" @@ -776,16 +817,16 @@ checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" [[package]] name = "rustix" -version = "0.36.9" +version = "0.37.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd5c6ff11fecd55b40746d1995a02f2eb375bf8c00d192d521ee09f42bef37bc" +checksum = "1aef160324be24d31a62147fae491c14d2204a3865c7ca8c3b0d7f7bcb3ea635" dependencies = [ "bitflags", "errno", "io-lifetimes", "libc", "linux-raw-sys", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -831,9 +872,9 @@ checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" [[package]] name = "serde" -version = "1.0.158" +version = "1.0.159" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "771d4d9c4163ee138805e12c710dd365e4f44be8be0503cb1bb9eb989425d9c9" +checksum = "3c04e8343c3daeec41f58990b9d77068df31209f2af111e059e9fe9646693065" dependencies = [ "serde_derive", ] @@ -849,20 +890,20 @@ dependencies = [ [[package]] name = "serde_derive" -version = "1.0.158" +version = "1.0.159" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e801c1712f48475582b7696ac71e0ca34ebb30e09338425384269d9717c62cad" +checksum = "4c614d17805b093df4b147b51339e7e44bf05ef59fba1e45d83500bcfb4d8585" dependencies = [ "proc-macro2", "quote", - "syn 2.0.10", + "syn 2.0.13", ] [[package]] name = "serde_json" -version = "1.0.94" +version = "1.0.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1c533a59c9d8a93a09c6ab31f0fd5e5f4dd1b8fc9434804029839884765d04ea" +checksum = "d721eca97ac802aa7777b701877c8004d950fc142651367300d21c1cc0194744" dependencies = [ "itoa", "ryu", @@ -945,9 +986,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.10" +version = "2.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5aad1363ed6d37b84299588d62d3a7d95b5a5c2d9aad5c85609fda12afaa1f40" +checksum = "4c9da457c5285ac1f936ebd076af6dac17a61cfe7826f2076b4d015cf47bc8ec" dependencies = [ "proc-macro2", "quote", @@ -965,22 +1006,22 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.39" +version = "1.0.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a5ab016db510546d856297882807df8da66a16fb8c4101cb8b30054b0d5b2d9c" +checksum = "978c9a314bd8dc99be594bc3c175faaa9794be04a5a5e153caba6915336cebac" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.39" +version = "1.0.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5420d42e90af0c38c3290abcca25b9b3bdf379fc9f55c528f53a269d9c9a267e" +checksum = "f9456a42c5b0d803c8cd86e73dd7cc9edd429499f37a3550d286d5e86720569f" dependencies = [ "proc-macro2", "quote", - "syn 1.0.109", + "syn 2.0.13", ] [[package]] @@ -1040,12 +1081,6 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" -[[package]] -name = "version_check" -version = "0.9.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" - [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" @@ -1100,7 +1135,16 @@ version = "0.45.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0" dependencies = [ - "windows-targets", + "windows-targets 0.42.2", +] + +[[package]] +name = "windows-sys" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" +dependencies = [ + "windows-targets 0.48.0", ] [[package]] @@ -1109,13 +1153,28 @@ version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8e5180c00cd44c9b1c88adb3693291f1cd93605ded80c250a75d472756b4d071" dependencies = [ - "windows_aarch64_gnullvm", - "windows_aarch64_msvc", - "windows_i686_gnu", - "windows_i686_msvc", - "windows_x86_64_gnu", - "windows_x86_64_gnullvm", - "windows_x86_64_msvc", + "windows_aarch64_gnullvm 0.42.2", + "windows_aarch64_msvc 0.42.2", + "windows_i686_gnu 0.42.2", + "windows_i686_msvc 0.42.2", + "windows_x86_64_gnu 0.42.2", + "windows_x86_64_gnullvm 0.42.2", + "windows_x86_64_msvc 0.42.2", +] + +[[package]] +name = "windows-targets" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b1eb6f0cd7c80c79759c929114ef071b87354ce476d9d94271031c0497adfd5" +dependencies = [ + "windows_aarch64_gnullvm 0.48.0", + "windows_aarch64_msvc 0.48.0", + "windows_i686_gnu 0.48.0", + "windows_i686_msvc 0.48.0", + "windows_x86_64_gnu 0.48.0", + "windows_x86_64_gnullvm 0.48.0", + "windows_x86_64_msvc 0.48.0", ] [[package]] @@ -1124,42 +1183,84 @@ version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8" +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91ae572e1b79dba883e0d315474df7305d12f569b400fcf90581b06062f7e1bc" + [[package]] name = "windows_aarch64_msvc" version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43" +[[package]] +name = "windows_aarch64_msvc" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2ef27e0d7bdfcfc7b868b317c1d32c641a6fe4629c171b8928c7b08d98d7cf3" + [[package]] name = "windows_i686_gnu" version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f" +[[package]] +name = "windows_i686_gnu" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "622a1962a7db830d6fd0a69683c80a18fda201879f0f447f065a3b7467daa241" + [[package]] name = "windows_i686_msvc" version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060" +[[package]] +name = "windows_i686_msvc" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4542c6e364ce21bf45d69fdd2a8e455fa38d316158cfd43b3ac1c5b1b19f8e00" + [[package]] name = "windows_x86_64_gnu" version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36" +[[package]] +name = "windows_x86_64_gnu" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ca2b8a661f7628cbd23440e50b05d705db3686f894fc9580820623656af974b1" + [[package]] name = "windows_x86_64_gnullvm" version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7896dbc1f41e08872e9d5e8f8baa8fdd2677f29468c4e156210174edc7f7b953" + [[package]] name = "windows_x86_64_msvc" version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0" +[[package]] +name = "windows_x86_64_msvc" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a515f5799fe4961cb532f983ce2b23082366b898e52ffbce459c86f67c8378a" + [[package]] name = "zstd" version = "0.12.3+zstd.1.5.2" @@ -1171,9 +1272,9 @@ dependencies = [ [[package]] name = "zstd-safe" -version = "6.0.4+zstd.1.5.4" +version = "6.0.5+zstd.1.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7afb4b54b8910cf5447638cb54bf4e8a65cbedd783af98b98c62ffe91f185543" +checksum = "d56d9e60b4b1758206c238a10165fbcae3ca37b01744e394c463463f6529d23b" dependencies = [ "libc", "zstd-sys", @@ -1181,9 +1282,9 @@ dependencies = [ [[package]] name = "zstd-sys" -version = "2.0.7+zstd.1.5.4" +version = "2.0.8+zstd.1.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94509c3ba2fe55294d752b79842c530ccfab760192521df74a081a78d2b3c7f5" +checksum = "5556e6ee25d32df2586c098bbfa278803692a20d0ab9565e049480d52707ec8c" dependencies = [ "cc", "libc", diff --git a/ggml/src/lib.rs b/ggml/src/lib.rs index f8331bb5..69204c5e 100644 --- a/ggml/src/lib.rs +++ b/ggml/src/lib.rs @@ -429,13 +429,21 @@ impl Tensor { /// # Safety /// /// The data must not be mutated while being read from. - pub unsafe fn data(&self) -> *mut c_void { + pub unsafe fn data(&self) -> *const c_void { self.with_alive_ctx(|| { // SAFETY: The with_alive_call guarantees the context is alive unsafe { *self.ptr.as_ptr() }.data }) } + /// Set the tensor's data pointer (useful for mmap-ed data) + pub unsafe fn set_data(&self, data_ptr: *mut c_void) { + self.with_alive_ctx(|| { + // SAFETY: The with_alive_call guarantees the context is alive + unsafe { *self.ptr.as_ptr() }.data = data_ptr; + }) + } + /// Number of elements in this tensor. 
pub fn nelements(&self) -> usize { self.with_alive_ctx(|| { diff --git a/llama-rs/Cargo.toml b/llama-rs/Cargo.toml index 0e48cd58..b2e3aa15 100644 --- a/llama-rs/Cargo.toml +++ b/llama-rs/Cargo.toml @@ -16,6 +16,7 @@ rand = { workspace = true } serde = { version = "1.0.156", features = ["derive"] } serde_bytes = "0.11" bincode = "1.3.3" +memmap2 = "0.5.10" # Used for the `convert` feature serde_json = { version = "1.0.94", optional = true } @@ -23,4 +24,4 @@ protobuf = { version = "= 2.14.0", optional = true } rust_tokenizers = { version = "3.1.2", optional = true } [features] -convert = ["dep:serde_json", "dep:protobuf", "dep:rust_tokenizers"] \ No newline at end of file +convert = ["dep:serde_json", "dep:protobuf", "dep:rust_tokenizers"] diff --git a/llama-rs/src/lib.rs b/llama-rs/src/lib.rs index ed68cf0e..417f7ea8 100644 --- a/llama-rs/src/lib.rs +++ b/llama-rs/src/lib.rs @@ -14,6 +14,7 @@ use std::{ }; use serde::Deserialize; +use memmap2::Mmap; use thiserror::Error; use partial_sort::PartialSort; @@ -73,6 +74,8 @@ pub struct Model { tensors: HashMap, + mmap: Option, + // Must be kept alive for the model _context: ggml::Context, } @@ -505,7 +508,7 @@ pub enum LoadError { /// The name of the tensor. tensor_name: String, /// The format type that was encountered. - ftype: u32, + ftype: i32, /// The path that failed. path: PathBuf, }, @@ -588,12 +591,13 @@ impl Model { let main_path = path.as_ref(); - let mut reader = - BufReader::new( - File::open(main_path).map_err(|e| LoadError::OpenFileFailed { + let file = File::open(main_path).map_err(|e| LoadError::OpenFileFailed { source: e, path: main_path.to_owned(), - })?, + })?; + let mut reader = + BufReader::new( + &file, ); // Verify magic @@ -735,7 +739,7 @@ impl Model { // Initialize the context let context = ggml::Context::init(ctx_size); - let model = { + let mut model = { let mut tensors = HashMap::new(); let tok_embeddings = context.new_tensor_2d(wtype, n_embd, n_vocab); @@ -799,15 +803,20 @@ impl Model { layers, tensors, _context: context, + mmap: None, } }; match model_type { ModelType::GGMF | ModelType::Unversioned => { - load_weights_ggmf_or_unversioned(reader, main_path, load_progress_callback, &model)? + let file_offset = reader.stream_position()?; + drop(reader); + load_weights_ggmf_or_unversioned(file_offset, main_path, load_progress_callback, &model)? } ModelType::GGJT => { - load_weights_ggjt(reader, main_path, load_progress_callback, &model)? + let mmap = unsafe { Mmap::map(&file)? }; + load_weights_ggjt(&mut reader, &mmap, main_path, load_progress_callback, &model)?; + model.mmap = Some(mmap); } } diff --git a/llama-rs/src/loader.rs b/llama-rs/src/loader.rs index 4269538d..da65ec6c 100644 --- a/llama-rs/src/loader.rs +++ b/llama-rs/src/loader.rs @@ -26,7 +26,7 @@ pub(crate) fn read_f32(reader: &mut impl BufRead) -> Result { } /// Helper function. Reads a string from the buffer and returns it. 
-pub(crate) fn read_string(reader: &mut BufReader, len: usize) -> Result { +pub(crate) fn read_string(reader: &mut impl BufRead, len: usize) -> Result { let mut buf = vec![0; len]; reader .read_exact(&mut buf) @@ -38,6 +38,11 @@ pub(crate) fn read_string(reader: &mut BufReader, len: usize) -> Result Result { + reader.fill_buf().map(|b| !b.is_empty()) +} + #[derive(PartialEq)] pub(crate) enum ModelType { GGMF, @@ -46,14 +51,11 @@ pub(crate) enum ModelType { } pub(crate) fn load_weights_ggmf_or_unversioned( - mut reader: std::io::BufReader, + file_offset: u64, main_path: &Path, load_progress_callback: impl Fn(LoadProgress), model: &Model, ) -> Result<(), LoadError> { - let file_offset = reader.stream_position()?; - drop(reader); - let paths = util::find_all_model_files(main_path)?; let n_parts = paths.len(); @@ -76,125 +78,23 @@ pub(crate) fn load_weights_ggmf_or_unversioned( // Load weights loop { - // NOTE: Implementation from #![feature(buf_read_has_data_left)] - let is_eof = part_reader.fill_buf().map(|b| b.is_empty())?; - - if is_eof { + if !has_data_left(&mut part_reader)? { break; } let n_dims = usize::try_from(read_i32(&mut part_reader)?)?; let length = read_i32(&mut part_reader)?; - let ftype = read_u32(&mut part_reader)?; - - let mut nelements = 1; - let mut ne = [1i64, 1i64]; - - #[allow(clippy::needless_range_loop)] - for i in 0..n_dims { - ne[i] = read_i32(&mut part_reader)? as i64; - nelements *= usize::try_from(ne[i])?; - } - - let tensor_name = read_string(&mut part_reader, length as usize)?; - - let Some(tensor) = model.tensors.get(&tensor_name) - else { - return Err(LoadError::UnknownTensor { tensor_name, path: part_path }); - }; - - // split_type = 0: split by columns - // split_type = 1: split by rows - // - // split_type = 0: - // regex: - // - tok_embeddings.* - // - layers.*.attention.wo.weight - // - layers.*.feed_forward.w2.weight - - // split_type = 1: - // regex: - // - output.* - // - layers.*.attention.wq.weight - // - layers.*.attention.wk.weight - // - layers.*.attention.wv.weight - // - layers.*.feed_forward.w1.weight - // - layers.*.feed_forward.w3.weight - #[allow(clippy::if_same_then_else)] - let split_type = if tensor_name.contains("tok_embeddings") { - 0 - } else if tensor_name.contains("layers") { - if tensor_name.contains("attention.wo.weight") { - 0 - } else if tensor_name.contains("feed_forward.w2.weight") { - 0 - } else { - 1 - } - } else if tensor_name.contains("output") { - 1 - } else { - 0 - }; - - if n_dims == 1 { - if tensor.nelements() != nelements { - return Err(LoadError::TensorWrongSize { - tensor_name, - path: part_path, - }); - } - } else if tensor.nelements() / n_parts != nelements { - return Err(LoadError::TensorWrongSize { - tensor_name, - path: part_path, - }); - } - - if n_dims == 1 { - if tensor.get_ne()[0] != ne[0] || tensor.get_ne()[1] != ne[1] { - return Err(LoadError::TensorWrongSize { - tensor_name, - path: part_path, - }); - } - } else if split_type == 0 { - if tensor.get_ne()[0] / i64::try_from(n_parts)? != ne[0] - || tensor.get_ne()[1] != ne[1] - { - return Err(LoadError::TensorWrongSize { - tensor_name, - path: part_path, - }); - } - } else if tensor.get_ne()[0] != ne[0] - || tensor.get_ne()[1] / i64::try_from(n_parts)? 
!= ne[1] - { - return Err(LoadError::TensorWrongSize { - tensor_name, - path: part_path, - }); - } - - let bpe = match ftype { - 0 => ggml::type_size(ggml::Type::F32), - 1 => ggml::type_size(ggml::Type::F16), - 2 => { - assert_eq!(ne[0] % 64, 0); - ggml::type_size(ggml::Type::Q4_0) - } - 3 => { - assert_eq!(ne[0] % 64, 0); - ggml::type_size(ggml::Type::Q4_1) - } - _ => { - return Err(LoadError::InvalidFtype { - tensor_name, - ftype, - path: part_path, - }) - } - }; + let ftype = read_i32(&mut part_reader)?; + + let (nelements, ne, tensor_name, tensor, split_type, bpe) = load_tensor_header_ggmf( + n_dims, + &mut part_reader, + length, + model, + &part_path, + n_parts, + ftype, + )?; if n_dims == 1 || n_parts == 1 { if (nelements * bpe) / ggml::blck_size(tensor.get_type()) != tensor.nbytes() { @@ -283,11 +183,188 @@ pub(crate) fn load_weights_ggmf_or_unversioned( }) } +fn load_tensor_header_ggmf<'a>( + n_dims: usize, + reader: &mut BufReader, + length: i32, + model: &'a Model, + path: &Path, + n_parts: usize, + ftype: i32, +) -> Result<(usize, [i64; 2], String, &'a ggml::Tensor, i32, usize), LoadError> { + let mut nelements = 1; + let mut ne = [1i64, 1i64]; + assert!(n_dims <= ne.len()); + #[allow(clippy::needless_range_loop)] + for i in 0..n_dims { + ne[i] = read_i32(reader)? as i64; + nelements *= usize::try_from(ne[i])?; + } + let tensor_name = read_string(reader, length as usize)?; + let Some(tensor) = model.tensors.get(&tensor_name) + else { + return Err(LoadError::UnknownTensor { tensor_name, path: path.to_owned() }); + }; + #[allow(clippy::if_same_then_else)] + let split_type = if tensor_name.contains("tok_embeddings") { + 0 + } else if tensor_name.contains("layers") { + if tensor_name.contains("attention.wo.weight") { + 0 + } else if tensor_name.contains("feed_forward.w2.weight") { + 0 + } else { + 1 + } + } else if tensor_name.contains("output") { + 1 + } else { + 0 + }; + if n_dims == 1 { + if tensor.nelements() != nelements { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: path.to_owned(), + }); + } + } else if tensor.nelements() / n_parts != nelements { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: path.to_owned(), + }); + } + if n_dims == 1 { + if tensor.get_ne()[0] != ne[0] || tensor.get_ne()[1] != ne[1] { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: path.to_owned(), + }); + } + } else if split_type == 0 { + if tensor.get_ne()[0] / i64::try_from(n_parts)? != ne[0] || tensor.get_ne()[1] != ne[1] { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: path.to_owned(), + }); + } + } else if tensor.get_ne()[0] != ne[0] || tensor.get_ne()[1] / i64::try_from(n_parts)? 
!= ne[1] { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: path.to_owned(), + }); + } + let bpe = tensor_type_size(ftype, ne); + let bpe = match bpe { + Some(x) => x, + None => { + return Err(LoadError::InvalidFtype { + tensor_name, + ftype, + path: path.to_owned(), + }); + } + }; + Ok((nelements, ne, tensor_name, tensor, split_type, bpe)) +} + +fn tensor_type_size(ftype: i32, ne: [i64; 2]) -> Option { + let bpe = match ftype { + 0 => Some(ggml::type_size(ggml::Type::F32)), + 1 => Some(ggml::type_size(ggml::Type::F16)), + 2 => { + assert_eq!(ne[0] % 64, 0); + Some(ggml::type_size(ggml::Type::Q4_0)) + } + 3 => { + assert_eq!(ne[0] % 64, 0); + Some(ggml::type_size(ggml::Type::Q4_1)) + } + _ => None, + }; + bpe +} + pub(crate) fn load_weights_ggjt( - mut reader: std::io::BufReader, - main_path: &Path, + reader: &mut std::io::BufReader<&File>, + mmap: &Mmap, + path: &Path, load_progress_callback: impl Fn(LoadProgress), model: &Model, -) -> Result<(), LoadError> { - todo!("GGJT load weights"); +) -> Result<(), LoadError> +// where R: std::io::Read +{ + let mut loop_i = 0; + let mut total_loaded_bytes = 0; + load_progress_callback(LoadProgress::PartLoading { + file: path, + current_part: 0, + total_parts: 1, + }); + + loop { + if !has_data_left(reader)? { + break; + } + + let n_dims = read_i32(reader)? as usize; + let length = read_i32(reader)?; + let ftype = read_i32(reader)?; + + let mut nelements: usize = 1; + let mut ne = [1i64, 1]; + assert!(n_dims <= ne.len()); + for i in 0..n_dims { + let dim = read_i32(reader)? as usize; + ne[i] = dim as i64; + nelements *= dim; + } + let tensor_name = read_string(reader, length as usize)?; + let Some(tensor) = model.tensors.get(&tensor_name) + else { + return Err(LoadError::UnknownTensor { tensor_name, path: path.to_owned() }); + }; + + if tensor.nelements() != nelements { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: path.to_owned(), + }); + } + let tensor_ne = tensor.get_ne(); + if tensor_ne[0] != ne[0] || tensor_ne[1] != ne[1] { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: path.to_owned(), + }); + } + + _ = tensor_type_size(ftype, ne); + + let offset_curr = reader.stream_position()?; + let offset_aligned: u64 = (offset_curr + 31) & (31 ^ u64::MAX); + unsafe { + let ptr = mmap.as_ptr().offset(offset_aligned as isize); + tensor.set_data(ptr as *mut std::ffi::c_void); + } + let tensor_data_size = tensor.nbytes() as u64; + reader.seek(SeekFrom::Start(offset_aligned + tensor_data_size))?; + total_loaded_bytes += tensor_data_size; + + load_progress_callback(LoadProgress::PartTensorLoaded { + file: path, + current_tensor: loop_i, + tensor_count: model.tensors.len(), + }); + + loop_i += 1; + } + + load_progress_callback(LoadProgress::PartLoaded { + file: path, + byte_size: total_loaded_bytes as usize, + tensor_count: loop_i, + }); + + return Ok(()); } From 9eefdc594ef82b4df69ba8525415a2522131a42b Mon Sep 17 00:00:00 2001 From: Locria Cyber <74560659+iacore@users.noreply.github.com> Date: Thu, 6 Apr 2023 20:09:40 +0000 Subject: [PATCH 03/42] code cleanup that doesn't change anything --- llama-rs/src/lib.rs | 46 ++++++++++++++++++++++++++++-------------- llama-rs/src/loader.rs | 11 +++++++++- 2 files changed, 41 insertions(+), 16 deletions(-) diff --git a/llama-rs/src/lib.rs b/llama-rs/src/lib.rs index 417f7ea8..ad240af7 100644 --- a/llama-rs/src/lib.rs +++ b/llama-rs/src/lib.rs @@ -592,13 +592,10 @@ impl Model { let main_path = path.as_ref(); let file = File::open(main_path).map_err(|e| 
LoadError::OpenFileFailed { - source: e, - path: main_path.to_owned(), - })?; - let mut reader = - BufReader::new( - &file, - ); + source: e, + path: main_path.to_owned(), + })?; + let mut reader = BufReader::new(&file); // Verify magic let model_type: ModelType = match read_u32(&mut reader)? { @@ -660,13 +657,21 @@ impl Model { ModelType::GGMF | ModelType::Unversioned => read_i32(&mut reader)? as usize, ModelType::GGJT => read_u32(&mut reader)? as usize, }; - if let Ok(word) = read_string(&mut reader, len) { - max_token_length = max_token_length.max(word.len()); - id_to_token.push(word.clone()); - token_to_id.insert(word, TokenId::try_from(i)?); + let maybe_word = if len > 0 { + read_string(&mut reader, len) } else { - load_progress_callback(LoadProgress::BadToken { index: i }); - id_to_token.push("�".to_string()); + Ok("".into()) + }; + match maybe_word { + Ok(word) => { + max_token_length = max_token_length.max(word.len()); + id_to_token.push(word.clone()); + token_to_id.insert(word, TokenId::try_from(i)?); + } + Err(_e) => { + load_progress_callback(LoadProgress::BadToken { index: i }); + id_to_token.push("�".to_string()); + } } // Token score, currently unused @@ -811,11 +816,22 @@ impl Model { ModelType::GGMF | ModelType::Unversioned => { let file_offset = reader.stream_position()?; drop(reader); - load_weights_ggmf_or_unversioned(file_offset, main_path, load_progress_callback, &model)? + load_weights_ggmf_or_unversioned( + file_offset, + main_path, + load_progress_callback, + &model, + )? } ModelType::GGJT => { let mmap = unsafe { Mmap::map(&file)? }; - load_weights_ggjt(&mut reader, &mmap, main_path, load_progress_callback, &model)?; + load_weights_ggjt( + &mut reader, + &mmap, + main_path, + load_progress_callback, + &model, + )?; model.mmap = Some(mmap); } } diff --git a/llama-rs/src/loader.rs b/llama-rs/src/loader.rs index da65ec6c..e7326e13 100644 --- a/llama-rs/src/loader.rs +++ b/llama-rs/src/loader.rs @@ -339,7 +339,16 @@ pub(crate) fn load_weights_ggjt( }); } - _ = tensor_type_size(ftype, ne); + match tensor_type_size(ftype, ne) { + Some(_) => {}, + None => { + return Err(LoadError::InvalidFtype { + tensor_name, + ftype, + path: path.to_owned(), + }); + } + }; let offset_curr = reader.stream_position()?; let offset_aligned: u64 = (offset_curr + 31) & (31 ^ u64::MAX); From c212c534feb1bb5b9cc9684d672f39a8debdbaaa Mon Sep 17 00:00:00 2001 From: Locria Cyber <74560659+iacore@users.noreply.github.com> Date: Thu, 6 Apr 2023 20:12:55 +0000 Subject: [PATCH 04/42] more code cleanup --- llama-rs/src/lib.rs | 36 ++++++++++++++++++++++++------------ llama-rs/src/loader.rs | 7 ------- 2 files changed, 24 insertions(+), 19 deletions(-) diff --git a/llama-rs/src/lib.rs b/llama-rs/src/lib.rs index ad240af7..14053612 100644 --- a/llama-rs/src/lib.rs +++ b/llama-rs/src/lib.rs @@ -60,6 +60,15 @@ struct Layer { w3: ggml::Tensor, } + +/// Model Version +#[derive(Debug, PartialEq, Clone, Copy)] +pub(crate) enum ModelVersion { + GGMF, + GGJT, + Unversioned, +} + /// The weights for the LLaMA model. All the mutable state is split into a /// separate struct `InferenceSession`. pub struct Model { @@ -75,6 +84,8 @@ pub struct Model { tensors: HashMap, mmap: Option, + + version: ModelVersion, // Must be kept alive for the model _context: ggml::Context, @@ -598,10 +609,10 @@ impl Model { let mut reader = BufReader::new(&file); // Verify magic - let model_type: ModelType = match read_u32(&mut reader)? 
{ - ggml::FILE_MAGIC_GGMF => ModelType::GGMF, - ggml::FILE_MAGIC_GGJT => ModelType::GGJT, - ggml::FILE_MAGIC_UNVERSIONED => ModelType::Unversioned, + let model_type: ModelVersion = match read_u32(&mut reader)? { + ggml::FILE_MAGIC_GGMF => ModelVersion::GGMF, + ggml::FILE_MAGIC_GGJT => ModelVersion::GGJT, + ggml::FILE_MAGIC_UNVERSIONED => ModelVersion::Unversioned, _ => { return Err(LoadError::InvalidMagic { path: main_path.to_owned(), @@ -611,13 +622,13 @@ impl Model { // Load format version match model_type { - ModelType::GGMF | ModelType::GGJT => { + ModelVersion::GGMF | ModelVersion::GGJT => { let _version: u32 = match read_u32(&mut reader)? { ggml::FORMAT_VERSION => ggml::FORMAT_VERSION, version => return Err(LoadError::InvalidFormatVersion { value: version }), }; } - ModelType::Unversioned => {} + ModelVersion::Unversioned => {} } // ================= @@ -654,8 +665,8 @@ impl Model { for i in 0..hparams.n_vocab { let len = match model_type { // `read_i32` maybe a typo - ModelType::GGMF | ModelType::Unversioned => read_i32(&mut reader)? as usize, - ModelType::GGJT => read_u32(&mut reader)? as usize, + ModelVersion::GGMF | ModelVersion::Unversioned => read_i32(&mut reader)? as usize, + ModelVersion::GGJT => read_u32(&mut reader)? as usize, }; let maybe_word = if len > 0 { read_string(&mut reader, len) @@ -676,12 +687,12 @@ impl Model { // Token score, currently unused match model_type { - ModelType::GGMF | ModelType::GGJT => { + ModelVersion::GGMF | ModelVersion::GGJT => { if let Ok(score) = read_f32(&mut reader) { id_to_token_score.push(score); } } - ModelType::Unversioned => { + ModelVersion::Unversioned => { // Legacy model, set empty score id_to_token_score.push(0.); } @@ -809,11 +820,12 @@ impl Model { tensors, _context: context, mmap: None, + version: model_type, } }; match model_type { - ModelType::GGMF | ModelType::Unversioned => { + ModelVersion::GGMF | ModelVersion::Unversioned => { let file_offset = reader.stream_position()?; drop(reader); load_weights_ggmf_or_unversioned( @@ -823,7 +835,7 @@ impl Model { &model, )? } - ModelType::GGJT => { + ModelVersion::GGJT => { let mmap = unsafe { Mmap::map(&file)? 
}; load_weights_ggjt( &mut reader, diff --git a/llama-rs/src/loader.rs b/llama-rs/src/loader.rs index e7326e13..602bd080 100644 --- a/llama-rs/src/loader.rs +++ b/llama-rs/src/loader.rs @@ -43,13 +43,6 @@ fn has_data_left(reader: &mut impl BufRead) -> Result { reader.fill_buf().map(|b| !b.is_empty()) } -#[derive(PartialEq)] -pub(crate) enum ModelType { - GGMF, - GGJT, - Unversioned, -} - pub(crate) fn load_weights_ggmf_or_unversioned( file_offset: u64, main_path: &Path, From bfaec3ac12282672cddc9d41128dcd2f33daaace Mon Sep 17 00:00:00 2001 From: Locria Cyber <74560659+iacore@users.noreply.github.com> Date: Fri, 7 Apr 2023 13:54:45 +0000 Subject: [PATCH 05/42] minor change in math, tensor loading --- llama-rs/src/loader.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llama-rs/src/loader.rs b/llama-rs/src/loader.rs index 602bd080..fa5ce73a 100644 --- a/llama-rs/src/loader.rs +++ b/llama-rs/src/loader.rs @@ -333,7 +333,7 @@ pub(crate) fn load_weights_ggjt( } match tensor_type_size(ftype, ne) { - Some(_) => {}, + Some(_) => {} None => { return Err(LoadError::InvalidFtype { tensor_name, @@ -344,7 +344,7 @@ pub(crate) fn load_weights_ggjt( }; let offset_curr = reader.stream_position()?; - let offset_aligned: u64 = (offset_curr + 31) & (31 ^ u64::MAX); + let offset_aligned: u64 = (offset_curr + 31) & !31; unsafe { let ptr = mmap.as_ptr().offset(offset_aligned as isize); tensor.set_data(ptr as *mut std::ffi::c_void); From b6044ee09c4035451c3938ff54e7959ec0c93bd3 Mon Sep 17 00:00:00 2001 From: Locria Cyber <74560659+iacore@users.noreply.github.com> Date: Fri, 7 Apr 2023 19:23:07 +0000 Subject: [PATCH 06/42] Add non-mmap loader for GGJT --- ggml/src/lib.rs | 2 +- llama-rs/Cargo.toml | 5 ++++- llama-rs/src/lib.rs | 44 +++++++++++++++++++++++++++--------------- llama-rs/src/loader.rs | 39 +++++++++++++++++++++++++++---------- 4 files changed, 62 insertions(+), 28 deletions(-) diff --git a/ggml/src/lib.rs b/ggml/src/lib.rs index 69204c5e..4491473a 100644 --- a/ggml/src/lib.rs +++ b/ggml/src/lib.rs @@ -429,7 +429,7 @@ impl Tensor { /// # Safety /// /// The data must not be mutated while being read from. - pub unsafe fn data(&self) -> *const c_void { + pub unsafe fn data(&self) -> *mut c_void { self.with_alive_ctx(|| { // SAFETY: The with_alive_call guarantees the context is alive unsafe { *self.ptr.as_ptr() }.data diff --git a/llama-rs/Cargo.toml b/llama-rs/Cargo.toml index b2e3aa15..302a1389 100644 --- a/llama-rs/Cargo.toml +++ b/llama-rs/Cargo.toml @@ -16,7 +16,7 @@ rand = { workspace = true } serde = { version = "1.0.156", features = ["derive"] } serde_bytes = "0.11" bincode = "1.3.3" -memmap2 = "0.5.10" +memmap2 = { version = "0.5.10", optional = true } # Used for the `convert` feature serde_json = { version = "1.0.94", optional = true } @@ -25,3 +25,6 @@ rust_tokenizers = { version = "3.1.2", optional = true } [features] convert = ["dep:serde_json", "dep:protobuf", "dep:rust_tokenizers"] + +# broken atm, see https://github.com/rustformers/llama-rs/pull/114#issuecomment-1500337463 +mmap = ["dep:memmap2"] diff --git a/llama-rs/src/lib.rs b/llama-rs/src/lib.rs index 14053612..f206af8e 100644 --- a/llama-rs/src/lib.rs +++ b/llama-rs/src/lib.rs @@ -2,6 +2,9 @@ //! LLaMA-rs is a Rust port of the llama.cpp project. This allows running inference for Facebook's LLaMA model on a CPU with good performance using full precision, f16 or 4-bit quantized versions of the model. 
mod loader; +mod util; +#[cfg(feature = "convert")] +pub mod convert; use core::slice; use std::{ @@ -14,18 +17,31 @@ use std::{ }; use serde::Deserialize; -use memmap2::Mmap; use thiserror::Error; - use partial_sort::PartialSort; use rand::{distributions::WeightedIndex, prelude::Distribution}; - pub use ggml::Type as ElementType; -#[cfg(feature = "convert")] -pub mod convert; +#[cfg(feature = "mmap")] +use memmap2::Mmap; + +/// dummy struct +#[cfg(not(feature = "mmap"))] +pub(crate) struct Mmap; + +/// dummy impl +#[cfg(not(feature = "mmap"))] +impl Mmap { + pub(crate) unsafe fn map(_: &std::fs::File) -> Result { + Ok(Mmap) + } + pub(crate) fn as_ptr(&self) -> *const u8 { + std::ptr::null() + } +} +// map + -mod util; /// The end of text token. pub const EOT_TOKEN_ID: TokenId = 2; // Hardcoded (for now?) @@ -60,7 +76,6 @@ struct Layer { w3: ggml::Tensor, } - /// Model Version #[derive(Debug, PartialEq, Clone, Copy)] pub(crate) enum ModelVersion { @@ -84,7 +99,7 @@ pub struct Model { tensors: HashMap, mmap: Option, - + version: ModelVersion, // Must be kept alive for the model @@ -665,7 +680,9 @@ impl Model { for i in 0..hparams.n_vocab { let len = match model_type { // `read_i32` maybe a typo - ModelVersion::GGMF | ModelVersion::Unversioned => read_i32(&mut reader)? as usize, + ModelVersion::GGMF | ModelVersion::Unversioned => { + read_i32(&mut reader)? as usize + } ModelVersion::GGJT => read_u32(&mut reader)? as usize, }; let maybe_word = if len > 0 { @@ -837,14 +854,9 @@ impl Model { } ModelVersion::GGJT => { let mmap = unsafe { Mmap::map(&file)? }; - load_weights_ggjt( - &mut reader, - &mmap, - main_path, - load_progress_callback, - &model, - )?; + let ptr = mmap.as_ptr(); model.mmap = Some(mmap); + load_weights_ggjt(&mut reader, ptr, main_path, load_progress_callback, &model)?; } } diff --git a/llama-rs/src/loader.rs b/llama-rs/src/loader.rs index fa5ce73a..9155f4e8 100644 --- a/llama-rs/src/loader.rs +++ b/llama-rs/src/loader.rs @@ -280,7 +280,7 @@ fn tensor_type_size(ftype: i32, ne: [i64; 2]) -> Option { pub(crate) fn load_weights_ggjt( reader: &mut std::io::BufReader<&File>, - mmap: &Mmap, + mmap_base: *const u8, path: &Path, load_progress_callback: impl Fn(LoadProgress), model: &Model, @@ -343,15 +343,9 @@ pub(crate) fn load_weights_ggjt( } }; - let offset_curr = reader.stream_position()?; - let offset_aligned: u64 = (offset_curr + 31) & !31; - unsafe { - let ptr = mmap.as_ptr().offset(offset_aligned as isize); - tensor.set_data(ptr as *mut std::ffi::c_void); - } - let tensor_data_size = tensor.nbytes() as u64; - reader.seek(SeekFrom::Start(offset_aligned + tensor_data_size))?; - total_loaded_bytes += tensor_data_size; + load_tensor(reader, mmap_base, tensor)?; + + total_loaded_bytes += tensor.nbytes() as u64; load_progress_callback(LoadProgress::PartTensorLoaded { file: path, @@ -370,3 +364,28 @@ pub(crate) fn load_weights_ggjt( return Ok(()); } + +#[cfg(feature = "mmap")] +fn load_tensor(reader: &mut BufReader<&File>, mmap_base: *const u8, tensor: &ggml::Tensor) -> Result<(), LoadError> { + let offset_curr = reader.stream_position()?; + let offset_aligned: u64 = (offset_curr + 31) & !31; + unsafe { + let ptr = mmap_base.offset(offset_aligned as isize); + tensor.set_data(ptr as *mut std::ffi::c_void); + } + reader.seek(SeekFrom::Start(offset_aligned + tensor.nbytes() as u8))?; + Ok(()) +} + +#[cfg(not(feature = "mmap"))] +fn load_tensor<'a>(reader: &mut BufReader<&File>, mmap_base: *const u8, tensor: &'a ggml::Tensor) -> Result<(), LoadError> { + _ = mmap_base; + let offset_curr 
= reader.stream_position()?; + let offset_aligned: u64 = (offset_curr + 31) & !31; + reader.seek(SeekFrom::Start(offset_aligned))?; + + let buf: &'a mut [u8] = unsafe { std::slice::from_raw_parts_mut(tensor.data() as *mut u8, tensor.nbytes()) }; + reader.read_exact(buf)?; + + Ok(()) +} From 1872dda8f342fc99128892bf9fff1a7d01267a63 Mon Sep 17 00:00:00 2001 From: Locria Cyber <74560659+iacore@users.noreply.github.com> Date: Fri, 7 Apr 2023 19:37:36 +0000 Subject: [PATCH 07/42] Prefer traits in loader.rs --- llama-rs/src/loader.rs | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/llama-rs/src/loader.rs b/llama-rs/src/loader.rs index 9155f4e8..78228ce4 100644 --- a/llama-rs/src/loader.rs +++ b/llama-rs/src/loader.rs @@ -1,5 +1,3 @@ -use std::{fs::File, io::BufReader}; - use crate::*; pub(crate) fn read_bytes(reader: &mut impl BufRead) -> Result<[u8; N], LoadError> { @@ -49,6 +47,8 @@ pub(crate) fn load_weights_ggmf_or_unversioned( load_progress_callback: impl Fn(LoadProgress), model: &Model, ) -> Result<(), LoadError> { + use std::{fs::File, io::BufReader}; + let paths = util::find_all_model_files(main_path)?; let n_parts = paths.len(); @@ -178,7 +178,7 @@ pub(crate) fn load_weights_ggmf_or_unversioned( fn load_tensor_header_ggmf<'a>( n_dims: usize, - reader: &mut BufReader, + reader: &mut impl BufRead, length: i32, model: &'a Model, path: &Path, @@ -279,7 +279,7 @@ fn tensor_type_size(ftype: i32, ne: [i64; 2]) -> Option { } pub(crate) fn load_weights_ggjt( - reader: &mut std::io::BufReader<&File>, + reader: &mut (impl BufRead + Seek), mmap_base: *const u8, path: &Path, load_progress_callback: impl Fn(LoadProgress), @@ -344,7 +344,7 @@ pub(crate) fn load_weights_ggjt( }; load_tensor(reader, mmap_base, tensor)?; - + total_loaded_bytes += tensor.nbytes() as u64; load_progress_callback(LoadProgress::PartTensorLoaded { @@ -366,7 +366,11 @@ pub(crate) fn load_weights_ggjt( } #[cfg(feature = "mmap")] -fn load_tensor(reader: &mut BufReader<&File>, mmap_base: *const u8, tensor: &ggml::Tensor) -> Result<(), LoadError> { +fn load_tensor( + reader: &mut (impl BufRead + Seek), + mmap_base: *const u8, + tensor: &ggml::Tensor, +) -> Result<(), LoadError> { let offset_curr = reader.stream_position()?; let offset_aligned: u64 = (offset_curr + 31) & !31; unsafe { @@ -378,13 +382,18 @@ fn load_tensor(reader: &mut BufReader<&File>, mmap_base: *const u8, tensor: &ggm } #[cfg(not(feature = "mmap"))] -fn load_tensor<'a>(reader: &mut BufReader<&File>, mmap_base: *const u8, tensor: &'a ggml::Tensor) -> Result<(), LoadError> { +fn load_tensor<'a>( + reader: &mut (impl BufRead + Seek), + mmap_base: *const u8, + tensor: &'a ggml::Tensor, +) -> Result<(), LoadError> { _ = mmap_base; let offset_curr = reader.stream_position()?; let offset_aligned: u64 = (offset_curr + 31) & !31; reader.seek(SeekFrom::Start(offset_aligned))?; - let buf: &'a mut [u8] = unsafe { std::slice::from_raw_parts_mut(tensor.data() as *mut u8, tensor.nbytes()) }; + let buf: &'a mut [u8] = + unsafe { std::slice::from_raw_parts_mut(tensor.data() as *mut u8, tensor.nbytes()) }; reader.read_exact(buf)?; Ok(()) From ec1fca7c08447dd87329887473a0b085a5d4d0bd Mon Sep 17 00:00:00 2001 From: Locria Cyber <74560659+iacore@users.noreply.github.com> Date: Fri, 7 Apr 2023 19:51:49 +0000 Subject: [PATCH 08/42] cargo fmt --- llama-rs/src/lib.rs | 12 +++++------- llama-rs/src/loader.rs | 6 +++--- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/llama-rs/src/lib.rs b/llama-rs/src/lib.rs index f206af8e..3e9b8191 
100644 --- a/llama-rs/src/lib.rs +++ b/llama-rs/src/lib.rs @@ -1,10 +1,10 @@ #![deny(missing_docs)] //! LLaMA-rs is a Rust port of the llama.cpp project. This allows running inference for Facebook's LLaMA model on a CPU with good performance using full precision, f16 or 4-bit quantized versions of the model. -mod loader; -mod util; #[cfg(feature = "convert")] pub mod convert; +mod loader; +mod util; use core::slice; use std::{ @@ -16,11 +16,11 @@ use std::{ time, }; -use serde::Deserialize; -use thiserror::Error; +pub use ggml::Type as ElementType; use partial_sort::PartialSort; use rand::{distributions::WeightedIndex, prelude::Distribution}; -pub use ggml::Type as ElementType; +use serde::Deserialize; +use thiserror::Error; #[cfg(feature = "mmap")] use memmap2::Mmap; @@ -41,8 +41,6 @@ impl Mmap { } // map - - /// The end of text token. pub const EOT_TOKEN_ID: TokenId = 2; // Hardcoded (for now?) diff --git a/llama-rs/src/loader.rs b/llama-rs/src/loader.rs index 78228ce4..912f924c 100644 --- a/llama-rs/src/loader.rs +++ b/llama-rs/src/loader.rs @@ -343,7 +343,7 @@ pub(crate) fn load_weights_ggjt( } }; - load_tensor(reader, mmap_base, tensor)?; + load_tensor_ggjt(reader, mmap_base, tensor)?; total_loaded_bytes += tensor.nbytes() as u64; @@ -366,7 +366,7 @@ pub(crate) fn load_weights_ggjt( } #[cfg(feature = "mmap")] -fn load_tensor( +fn load_tensor_ggjt( reader: &mut (impl BufRead + Seek), mmap_base: *const u8, tensor: &ggml::Tensor, @@ -382,7 +382,7 @@ fn load_tensor( } #[cfg(not(feature = "mmap"))] -fn load_tensor<'a>( +fn load_tensor_ggjt<'a>( reader: &mut (impl BufRead + Seek), mmap_base: *const u8, tensor: &'a ggml::Tensor, From cc846aee76b09aa434430ba0ccf9d35c5fe9cbdc Mon Sep 17 00:00:00 2001 From: Locria Cyber <74560659+iacore@users.noreply.github.com> Date: Fri, 7 Apr 2023 20:33:02 +0000 Subject: [PATCH 09/42] cargo clippy --fix --- ggml/src/lib.rs | 6 +++++- llama-rs/src/convert.rs | 4 ++-- llama-rs/src/lib.rs | 5 +++-- llama-rs/src/loader.rs | 15 +++++++++------ 4 files changed, 19 insertions(+), 11 deletions(-) diff --git a/ggml/src/lib.rs b/ggml/src/lib.rs index 4491473a..7a78e448 100644 --- a/ggml/src/lib.rs +++ b/ggml/src/lib.rs @@ -428,7 +428,7 @@ impl Tensor { /// /// # Safety /// - /// The data must not be mutated while being read from. + /// Only `std::slice::from_raw_parts_mut(tensor.data(), tensor.nbytes())` is safe to mutate. pub unsafe fn data(&self) -> *mut c_void { self.with_alive_ctx(|| { // SAFETY: The with_alive_call guarantees the context is alive @@ -437,6 +437,10 @@ impl Tensor { } /// Set the tensor's data pointer (useful for mmap-ed data) + /// + /// # Safety + /// + /// The memory region from `data_ptr` to `data_ptr.offset(tensor.nbytes())` will be read from. 
pub unsafe fn set_data(&self, data_ptr: *mut c_void) { self.with_alive_ctx(|| { // SAFETY: The with_alive_call guarantees the context is alive diff --git a/llama-rs/src/convert.rs b/llama-rs/src/convert.rs index 285cf3c0..f4f55996 100644 --- a/llama-rs/src/convert.rs +++ b/llama-rs/src/convert.rs @@ -28,12 +28,12 @@ pub fn convert_pth_to_ggml(model_directory: &Path, element_type: ggml::Type) { let model_files = util::find_all_model_files(model_directory).unwrap(); for (i, _file) in model_files.iter().enumerate() { - let fname_out = model_directory.join(format!("rust-model-{}.bin", element_type)); + let fname_out = model_directory.join(format!("rust-model-{element_type}.bin")); let mut file = File::create(fname_out).expect("Unable to create file"); write_header(file.borrow_mut(), &hparams).unwrap(); write_tokens(file.borrow_mut(), &vocab).unwrap(); - let _fname_model = model_directory.join(format!("consolidated.0{}.pth", i)); + let _fname_model = model_directory.join(format!("consolidated.0{i}.pth")); // Todo process and write variables } } diff --git a/llama-rs/src/lib.rs b/llama-rs/src/lib.rs index 3e9b8191..a1aa4775 100644 --- a/llama-rs/src/lib.rs +++ b/llama-rs/src/lib.rs @@ -76,6 +76,7 @@ struct Layer { /// Model Version #[derive(Debug, PartialEq, Clone, Copy)] +#[allow(clippy::upper_case_acronyms)] pub(crate) enum ModelVersion { GGMF, GGJT, @@ -98,7 +99,7 @@ pub struct Model { mmap: Option, - version: ModelVersion, + _version: ModelVersion, // Must be kept alive for the model _context: ggml::Context, @@ -835,7 +836,7 @@ impl Model { tensors, _context: context, mmap: None, - version: model_type, + _version: model_type, } }; diff --git a/llama-rs/src/loader.rs b/llama-rs/src/loader.rs index 912f924c..fc3b2c61 100644 --- a/llama-rs/src/loader.rs +++ b/llama-rs/src/loader.rs @@ -52,7 +52,7 @@ pub(crate) fn load_weights_ggmf_or_unversioned( let paths = util::find_all_model_files(main_path)?; let n_parts = paths.len(); - Ok(for (i, part_path) in paths.into_iter().enumerate() { + for (i, part_path) in paths.into_iter().enumerate() { let part_id = i; load_progress_callback(LoadProgress::PartLoading { @@ -173,9 +173,11 @@ pub(crate) fn load_weights_ggmf_or_unversioned( byte_size: total_size, tensor_count: n_tensors.try_into()?, }); - }) + }; + Ok(()) } +#[allow(clippy::type_complexity)] fn load_tensor_header_ggmf<'a>( n_dims: usize, reader: &mut impl BufRead, @@ -262,7 +264,8 @@ fn load_tensor_header_ggmf<'a>( } fn tensor_type_size(ftype: i32, ne: [i64; 2]) -> Option { - let bpe = match ftype { + + match ftype { 0 => Some(ggml::type_size(ggml::Type::F32)), 1 => Some(ggml::type_size(ggml::Type::F16)), 2 => { @@ -274,8 +277,7 @@ fn tensor_type_size(ftype: i32, ne: [i64; 2]) -> Option { Some(ggml::type_size(ggml::Type::Q4_1)) } _ => None, - }; - bpe + } } pub(crate) fn load_weights_ggjt( @@ -307,6 +309,7 @@ pub(crate) fn load_weights_ggjt( let mut nelements: usize = 1; let mut ne = [1i64, 1]; assert!(n_dims <= ne.len()); + #[allow(clippy::needless_range_loop)] for i in 0..n_dims { let dim = read_i32(reader)? 
as usize; ne[i] = dim as i64; @@ -362,7 +365,7 @@ pub(crate) fn load_weights_ggjt( tensor_count: loop_i, }); - return Ok(()); + Ok(()) } #[cfg(feature = "mmap")] From bf847dd894b578a29bbec39ecba8d8b5c559b769 Mon Sep 17 00:00:00 2001 From: Locria Cyber <74560659+iacore@users.noreply.github.com> Date: Fri, 7 Apr 2023 20:36:09 +0000 Subject: [PATCH 10/42] Remove ggml::Tensor::set_data --- ggml/src/lib.rs | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/ggml/src/lib.rs b/ggml/src/lib.rs index 7a78e448..9128b354 100644 --- a/ggml/src/lib.rs +++ b/ggml/src/lib.rs @@ -436,17 +436,17 @@ impl Tensor { }) } - /// Set the tensor's data pointer (useful for mmap-ed data) - /// - /// # Safety - /// - /// The memory region from `data_ptr` to `data_ptr.offset(tensor.nbytes())` will be read from. - pub unsafe fn set_data(&self, data_ptr: *mut c_void) { - self.with_alive_ctx(|| { - // SAFETY: The with_alive_call guarantees the context is alive - unsafe { *self.ptr.as_ptr() }.data = data_ptr; - }) - } + // /// Set the tensor's data pointer (useful for mmap-ed data) + // /// + // /// # Safety + // /// + // /// The memory region from `data_ptr` to `data_ptr.offset(tensor.nbytes())` will be read from. + // pub unsafe fn set_data(&self, data_ptr: *mut c_void) { + // self.with_alive_ctx(|| { + // // SAFETY: The with_alive_call guarantees the context is alive + // unsafe { *self.ptr.as_ptr() }.data = data_ptr; + // }) + // } /// Number of elements in this tensor. pub fn nelements(&self) -> usize { From ea7094ccc86e6e5fa2a17cd318f5cc2185dd5fe6 Mon Sep 17 00:00:00 2001 From: Philpax Date: Fri, 7 Apr 2023 19:01:24 +0200 Subject: [PATCH 11/42] fix(llama): buffer tokens until valid UTF-8 --- llama-cli/src/cli_args.rs | 3 - llama-rs/src/convert.rs | 26 ++++--- llama-rs/src/lib.rs | 157 +++++++++++++++++++++++++++----------- 3 files changed, 129 insertions(+), 57 deletions(-) diff --git a/llama-cli/src/cli_args.rs b/llama-cli/src/cli_args.rs index e4b54d35..5dd663db 100644 --- a/llama-cli/src/cli_args.rs +++ b/llama-cli/src/cli_args.rs @@ -268,9 +268,6 @@ impl ModelLoad { LoadProgress::HyperparametersLoaded(hparams) => { log::debug!("Loaded hyperparameters {hparams:#?}") } - LoadProgress::BadToken { index } => { - log::info!("Warning: Bad token in vocab at index {index}") - } LoadProgress::ContextSize { bytes } => log::info!( "ggml ctx size = {:.2} MB\n", bytes as f64 / (1024.0 * 1024.0) diff --git a/llama-rs/src/convert.rs b/llama-rs/src/convert.rs index f4f55996..1b7a909c 100644 --- a/llama-rs/src/convert.rs +++ b/llama-rs/src/convert.rs @@ -49,11 +49,12 @@ fn load_vocabulary(path: &Path) -> Vocabulary { let mut token_to_id = HashMap::new(); let mut max_token_length = 0; + // TODO: Does the original model use valid UTF-8 for its tokens? This seems a little suspect to me. 
for (idx, piece) in proto.get_pieces().iter().enumerate() { - let word = piece.get_piece().to_string(); + let word = piece.get_piece().as_bytes(); max_token_length = max_token_length.max(word.len()); - id_to_token.push(word.clone()); - token_to_id.insert(word, idx as i32); + id_to_token.push(word.to_owned()); + token_to_id.insert(word.to_owned(), idx as i32); id_to_token_score.push(piece.get_score()); } Vocabulary { @@ -128,13 +129,20 @@ fn write_header(fout: &mut File, hparams: &Hyperparameters) -> Result<(), String fn write_tokens(file: &mut File, vocab: &Vocabulary) -> Result<(), String> { let mut values: Vec = vec![]; for (i, token) in vocab.id_to_token.iter().enumerate() { - let text = match token { - _ if token.contains("") => " \u{2047} ".as_bytes().to_vec(), - _ if token.contains("s>") => vec![], - _ if token.len() == 6 && token.contains("<0x") => { - vec![u8::from_str_radix(&token[3..5], 16).unwrap()] + // TODO: Not sure what the behaviour should be if the token is not valid UTF-8. + // + // Switching to the HF tokenizer should fix this. + let text = if let Ok(token) = std::str::from_utf8(token) { + match token { + _ if token.contains("") => " \u{2047} ".as_bytes().to_vec(), + _ if token.contains("s>") => vec![], + _ if token.len() == 6 && token.contains("<0x") => { + vec![u8::from_str_radix(&token[3..5], 16).unwrap()] + } + _ => token.replace('\u{2581}', " ").as_bytes().to_vec(), } - _ => token.replace('\u{2581}', " ").as_bytes().to_vec(), + } else { + token.clone() }; values.extend((text.len() as i32).to_le_bytes()); values.extend(&text); diff --git a/llama-rs/src/lib.rs b/llama-rs/src/lib.rs index a1aa4775..df169130 100644 --- a/llama-rs/src/lib.rs +++ b/llama-rs/src/lib.rs @@ -319,7 +319,7 @@ impl Display for InferenceStats { } type TokenId = i32; -type Token = String; +type Token = Vec; type TokenScore = f32; /// The vocabulary used by a model. @@ -339,7 +339,7 @@ pub struct Vocabulary { max_token_length: usize, } impl Vocabulary { - fn token(&self, idx: usize) -> &str { + fn token(&self, idx: usize) -> &[u8] { &self.id_to_token[idx] } } @@ -416,14 +416,6 @@ impl std::fmt::Display for TokenBias { pub enum LoadProgress<'a> { /// The hyperparameters have been loaded from the model. HyperparametersLoaded(&'a Hyperparameters), - /// A bad token was encountered during the loading process. - /// - /// This can be ignored, but invalid tokens will be replaced with - /// the `�` character. - BadToken { - /// The index within the vocabulary. - index: usize, - }, /// The context has been created. ContextSize { /// The size of the context. @@ -622,6 +614,48 @@ impl Model { })?; let mut reader = BufReader::new(&file); + fn read_bytes(reader: &mut impl BufRead) -> Result<[u8; N], LoadError> { + let mut bytes = [0u8; N]; + reader + .read_exact(&mut bytes) + .map_err(|e| LoadError::ReadExactFailed { + source: e, + bytes: N, + })?; + Ok(bytes) + } + + fn read_bytes_with_len( + reader: &mut impl BufRead, + len: usize, + ) -> Result, LoadError> { + let mut bytes = vec![0u8; len]; + reader + .read_exact(&mut bytes) + .map_err(|e| LoadError::ReadExactFailed { + source: e, + bytes: len, + })?; + Ok(bytes) + } + + fn read_i32(reader: &mut impl BufRead) -> Result { + Ok(i32::from_le_bytes(read_bytes::<4>(reader)?)) + } + + fn read_u32(reader: &mut impl BufRead) -> Result { + Ok(u32::from_le_bytes(read_bytes::<4>(reader)?)) + } + + fn read_f32(reader: &mut impl BufRead) -> Result { + Ok(f32::from_le_bytes(read_bytes::<4>(reader)?)) + } + + /// Helper function. 
Reads a string from the buffer and returns it. + fn read_string(reader: &mut BufReader, len: usize) -> Result { + Ok(String::from_utf8(read_bytes_with_len(reader, len)?)?) + } + // Verify magic let model_type: ModelVersion = match read_u32(&mut reader)? { ggml::FILE_MAGIC_GGMF => ModelVersion::GGMF, @@ -677,29 +711,11 @@ impl Model { let mut max_token_length = 0; for i in 0..hparams.n_vocab { - let len = match model_type { - // `read_i32` maybe a typo - ModelVersion::GGMF | ModelVersion::Unversioned => { - read_i32(&mut reader)? as usize - } - ModelVersion::GGJT => read_u32(&mut reader)? as usize, - }; - let maybe_word = if len > 0 { - read_string(&mut reader, len) - } else { - Ok("".into()) - }; - match maybe_word { - Ok(word) => { - max_token_length = max_token_length.max(word.len()); - id_to_token.push(word.clone()); - token_to_id.insert(word, TokenId::try_from(i)?); - } - Err(_e) => { - load_progress_callback(LoadProgress::BadToken { index: i }); - id_to_token.push("�".to_string()); - } - } + let len = read_i32(&mut reader)?; + let token = read_bytes_with_len(&mut reader, len as usize)?; + max_token_length = max_token_length.max(token.len()); + id_to_token.push(token.clone()); + token_to_id.insert(token, TokenId::try_from(i)?); // Token score, currently unused match model_type { @@ -1225,7 +1241,7 @@ impl InferenceSession { vocab: &Vocabulary, params: &InferenceParameters, prompt: &str, - callback: impl Fn(&str) -> Result<(), E>, + mut callback: impl FnMut(&[u8]) -> Result<(), E>, ) -> Result<(), InferenceError> { let beginning_of_sentence = self.n_past == 0; let prompt_tokens: Vec = vocab @@ -1262,7 +1278,7 @@ impl InferenceSession { vocab: &'v Vocabulary, params: &InferenceParameters, rng: &mut impl rand::Rng, - ) -> Result<&'v str, InferenceError> { + ) -> Result<&'v [u8], InferenceError> { if self.n_past + 1 >= model.hparams.n_ctx { return Err(InferenceError::ContextFull); } @@ -1303,15 +1319,19 @@ impl InferenceSession { prompt: &str, maximum_token_count: Option, rng: &mut impl rand::Rng, - callback: impl Fn(&str) -> Result<(), E>, + mut callback: impl FnMut(&str) -> Result<(), E>, ) -> Result { let maximum_token_count = maximum_token_count.unwrap_or(usize::MAX); if params.play_back_previous_tokens { // "Play back" the existing tokens, so that loading from an inference snapshot works // as expected. + let mut token_utf8_buf = TokenUtf8Buffer::new(); for token_id in &self.tokens { - if let Err(e) = callback(vocab.token(*token_id as usize)) { - return Err(InferenceError::UserCallback(Box::new(e))); + // Buffer the token until it's valid UTF-8, then call the callback. + if let Some(tokens) = token_utf8_buf.push(vocab.token(*token_id as usize)) { + if let Err(e) = callback(&tokens) { + return Err(InferenceError::UserCallback(Box::new(e))); + } } } } @@ -1322,7 +1342,13 @@ impl InferenceSession { // Feed the initial prompt through the transformer, to update its // context window with new data. - self.feed_prompt(model, vocab, params, prompt, |tk| callback(tk))?; + self.feed_prompt( + model, + vocab, + params, + prompt, + TokenUtf8Buffer::adapt_callback(&mut callback), + )?; stats.feed_prompt_duration = start_at.elapsed().unwrap(); stats.prompt_tokens = self.n_past; @@ -1331,6 +1357,7 @@ impl InferenceSession { // EndOfText token, or we run out of space in the context window, // or we reach the specified limit. 
let mut tokens_processed = 0; + let mut token_utf8_buf = TokenUtf8Buffer::new(); while tokens_processed < maximum_token_count { let token = match self.infer_next_token(model, vocab, params, rng) { Ok(token) => token, @@ -1338,8 +1365,11 @@ impl InferenceSession { Err(e) => return Err(e), }; - if let Err(e) = callback(token) { - return Err(InferenceError::UserCallback(Box::new(e))); + // Buffer the token until it's valid UTF-8, then call the callback. + if let Some(tokens) = token_utf8_buf.push(token) { + if let Err(e) = callback(&tokens) { + return Err(InferenceError::UserCallback(Box::new(e))); + } } tokens_processed += 1; @@ -1486,7 +1516,7 @@ impl Vocabulary { &'a self, text: &str, bos: bool, - ) -> Result, InferenceError> { + ) -> Result, InferenceError> { let len = text.len(); let mut score = vec![0usize; len + 1]; @@ -1496,7 +1526,6 @@ impl Vocabulary { let max_len = (len - i).min(self.max_token_length); for sub_len in 1..=max_len { let sub = &text.as_bytes()[i..i + sub_len]; - let Ok(sub) = std::str::from_utf8(sub) else { continue; }; let token = self.token_to_id.get(sub); if let Some(token) = token { @@ -1520,14 +1549,14 @@ impl Vocabulary { if token_id == 0 { return Err(InferenceError::TokenizationFailed); } - let token = self.id_to_token[token_id as usize].as_str(); + let token = self.id_to_token[token_id as usize].as_slice(); res.push((token, token_id)); i -= token.len(); } if bos { // TODO: replace with vocab.bos - res.push(("", 1)); + res.push((&[], 1)); } // Pieces are in reverse order so correct that @@ -1536,3 +1565,41 @@ impl Vocabulary { Ok(res) } } + +/// Used to buffer incoming tokens until they produce a valid string of UTF-8 text. +/// +/// Tokens are *not* valid UTF-8 by themselves. However, the LLM will produce valid UTF-8 +/// from multiple tokens. This helps alleviate that issue. +#[derive(Clone, PartialEq, Default)] +pub struct TokenUtf8Buffer(Vec); +impl TokenUtf8Buffer { + /// Create a new buffer. + pub const fn new() -> Self { + Self(vec![]) + } + + /// Add a token to the buffer. If the buffer contains a valid string of UTF-8 text, + /// it is returned and the buffer is cleared for next use. + pub fn push(&mut self, token: &[u8]) -> Option { + self.0.extend_from_slice(token); + match std::str::from_utf8(&self.0) { + Ok(s) => { + let out = s.to_owned(); + self.0 = vec![]; + Some(out) + } + Err(..) => None, + } + } + + /// Adapt a `&str` callback so that it can be used in a `&[u8]` context. + fn adapt_callback<'a, E: std::error::Error + 'static>( + mut callback: impl FnMut(&str) -> Result<(), E> + 'a, + ) -> impl FnMut(&[u8]) -> Result<(), E> + 'a { + let mut buffer = Self::new(); + move |token| match buffer.push(token) { + Some(tokens) => callback(&tokens), + None => Ok(()), + } + } +} From c848d5ea12ea7f945bf1df80358d3898e06c3c4f Mon Sep 17 00:00:00 2001 From: Locria Cyber <74560659+iacore@users.noreply.github.com> Date: Sat, 8 Apr 2023 11:50:40 +0000 Subject: [PATCH 12/42] Add standalone loader --- ggml/src/lib.rs | 3 +- llama-rs/src/convert.rs | 12 +- llama-rs/src/lib.rs | 62 +++++----- llama-rs/src/loader.rs | 20 ++-- llama-rs/src/loader2.rs | 260 ++++++++++++++++++++++++++++++++++++++++ 5 files changed, 305 insertions(+), 52 deletions(-) create mode 100644 llama-rs/src/loader2.rs diff --git a/ggml/src/lib.rs b/ggml/src/lib.rs index 9128b354..915d5692 100644 --- a/ggml/src/lib.rs +++ b/ggml/src/lib.rs @@ -24,10 +24,11 @@ pub const FILE_MAGIC_UNVERSIONED: u32 = 0x67676d6c; /// The currently-supported format version for `ggml` files. 
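// `TokenUtf8Buffer`, added in the "fix(llama): buffer tokens until valid UTF-8" patch above,
// exists because a single LLaMA token may be only part of a multi-byte UTF-8 sequence; raw
// token bytes are buffered until they decode cleanly. A small standalone usage sketch,
// assuming the crate is imported as `llama_rs` and `TokenUtf8Buffer` is exported from the
// crate root as shown in that patch:
use llama_rs::TokenUtf8Buffer;

fn main() {
    let mut buf = TokenUtf8Buffer::new();
    // 'é' is the two-byte UTF-8 sequence [0xC3, 0xA9]; the first byte alone is not valid UTF-8.
    assert_eq!(buf.push(&[0xC3]), None); // held back until the sequence completes
    assert_eq!(buf.push(&[0xA9]), Some("é".to_string())); // emitted once the bytes decode
    assert_eq!(buf.push(b"ok"), Some("ok".to_string())); // plain ASCII passes straight through
}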
pub const FORMAT_VERSION: u32 = 1; -#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, Default)] /// The type of a value in `ggml`. pub enum Type { /// Quantized 4-bit (type 0). + #[default] Q4_0, /// Quantized 4-bit (type 1); used by GPTQ. Q4_1, diff --git a/llama-rs/src/convert.rs b/llama-rs/src/convert.rs index 1b7a909c..450150b5 100644 --- a/llama-rs/src/convert.rs +++ b/llama-rs/src/convert.rs @@ -16,7 +16,7 @@ use std::{ vec, }; -use crate::{util, Hyperparameters, Vocabulary}; +use crate::{util, Hyperparameters, Vocabulary, loader2::encode_element_type}; /// Converts a `pth` file to a `ggml` file. pub fn convert_pth_to_ggml(model_directory: &Path, element_type: ggml::Type) { @@ -82,13 +82,7 @@ fn load_hyperparameters( let json = read_to_string(path.join("params.json")).expect("Unable to read file"); let json: HyperParametersJson = serde_json::from_str(&json).expect("Unable to parse json"); Hyperparameters { - f16_: match element_type { - ggml::Type::F32 => 0, - ggml::Type::F16 => 1, - ggml::Type::Q4_0 => 2, - ggml::Type::Q4_1 => 3, - _ => panic!("unsupported element type"), - }, + element_type, n_ctx: 0, n_embd: json.dim, n_head: json.n_heads, @@ -112,7 +106,7 @@ fn write_header(fout: &mut File, hparams: &Hyperparameters) -> Result<(), String i32::try_from(hparams.n_head).unwrap(), i32::try_from(hparams.n_layer).unwrap(), i32::try_from(hparams.n_embd / hparams.n_head).unwrap(), - i32::try_from(hparams.f16_).unwrap(), + encode_element_type(hparams.element_type).unwrap(), ]; let mut packed_values: Vec = vec![]; diff --git a/llama-rs/src/lib.rs b/llama-rs/src/lib.rs index df169130..0965095f 100644 --- a/llama-rs/src/lib.rs +++ b/llama-rs/src/lib.rs @@ -4,6 +4,7 @@ #[cfg(feature = "convert")] pub mod convert; mod loader; +pub mod loader2; mod util; use core::slice; @@ -19,12 +20,13 @@ use std::{ pub use ggml::Type as ElementType; use partial_sort::PartialSort; use rand::{distributions::WeightedIndex, prelude::Distribution}; -use serde::Deserialize; use thiserror::Error; #[cfg(feature = "mmap")] use memmap2::Mmap; +use crate::loader2::decode_element_type; + /// dummy struct #[cfg(not(feature = "mmap"))] pub(crate) struct Mmap; @@ -45,7 +47,7 @@ impl Mmap { pub const EOT_TOKEN_ID: TokenId = 2; // Hardcoded (for now?) /// The hyperparameters of the model. -#[derive(Debug, Default, PartialEq, Eq, PartialOrd, Ord, Deserialize)] +#[derive(Debug, Default, PartialEq, Eq)] pub struct Hyperparameters { n_vocab: usize, n_ctx: usize, @@ -54,7 +56,7 @@ pub struct Hyperparameters { n_head: usize, n_layer: usize, n_rot: usize, - f16_: u32, + element_type: ElementType, } struct Layer { @@ -74,12 +76,15 @@ struct Layer { w3: ggml::Tensor, } -/// Model Version +/// file type containing the model #[derive(Debug, PartialEq, Clone, Copy)] #[allow(clippy::upper_case_acronyms)] -pub(crate) enum ModelVersion { +pub enum ModelContainerType { + /// older than `GGJT` GGMF, + /// mmap-able format GGJT, + /// oldest ggml tensor file format Unversioned, } @@ -99,7 +104,7 @@ pub struct Model { mmap: Option, - _version: ModelVersion, + _version: ModelContainerType, // Must be kept alive for the model _context: ggml::Context, @@ -412,7 +417,7 @@ impl std::fmt::Display for TokenBias { /// Each variant represents a step within the process of loading the model. /// These can be used to report progress to the user. 
-#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Debug)] +#[derive(Clone, PartialEq, Eq, Debug)] pub enum LoadProgress<'a> { /// The hyperparameters have been loaded from the model. HyperparametersLoaded(&'a Hyperparameters), @@ -484,6 +489,9 @@ pub enum LoadError { #[error("invalid integer conversion")] /// One of the integers encountered could not be converted to a more appropriate type. InvalidIntegerConversion(#[from] std::num::TryFromIntError), + #[error("unsupported f16_: {0}")] + /// One of the integers encountered could not be converted to a more appropriate type. + UnsupportedElementtype(i32), #[error("invalid magic number for {path:?}")] /// An invalid magic number was encountered during the loading process. InvalidMagic { @@ -657,10 +665,10 @@ impl Model { } // Verify magic - let model_type: ModelVersion = match read_u32(&mut reader)? { - ggml::FILE_MAGIC_GGMF => ModelVersion::GGMF, - ggml::FILE_MAGIC_GGJT => ModelVersion::GGJT, - ggml::FILE_MAGIC_UNVERSIONED => ModelVersion::Unversioned, + let model_type: ModelContainerType = match read_u32(&mut reader)? { + ggml::FILE_MAGIC_GGMF => ModelContainerType::GGMF, + ggml::FILE_MAGIC_GGJT => ModelContainerType::GGJT, + ggml::FILE_MAGIC_UNVERSIONED => ModelContainerType::Unversioned, _ => { return Err(LoadError::InvalidMagic { path: main_path.to_owned(), @@ -670,13 +678,13 @@ impl Model { // Load format version match model_type { - ModelVersion::GGMF | ModelVersion::GGJT => { + ModelContainerType::GGMF | ModelContainerType::GGJT => { let _version: u32 = match read_u32(&mut reader)? { ggml::FORMAT_VERSION => ggml::FORMAT_VERSION, version => return Err(LoadError::InvalidFormatVersion { value: version }), }; } - ModelVersion::Unversioned => {} + ModelContainerType::Unversioned => {} } // ================= @@ -693,7 +701,10 @@ impl Model { n_head: read_i32(&mut reader)?.try_into()?, n_layer: read_i32(&mut reader)?.try_into()?, n_rot: read_i32(&mut reader)?.try_into()?, - f16_: read_i32(&mut reader)?.try_into()?, + element_type: { + let ftype = read_i32(&mut reader)?; + decode_element_type(ftype).ok_or_else(|| LoadError::UnsupportedElementtype(ftype)) + }?, }; let n_ff = @@ -719,12 +730,11 @@ impl Model { // Token score, currently unused match model_type { - ModelVersion::GGMF | ModelVersion::GGJT => { - if let Ok(score) = read_f32(&mut reader) { - id_to_token_score.push(score); - } + ModelContainerType::GGMF | ModelContainerType::GGJT => { + let score = read_f32(&mut reader)?; + id_to_token_score.push(score); } - ModelVersion::Unversioned => { + ModelContainerType::Unversioned => { // Legacy model, set empty score id_to_token_score.push(0.); } @@ -742,13 +752,7 @@ impl Model { // for the big tensors, we have the option to store the data in 16-bit // floats or quantized in order to save memory and also to speed up the // computation - let wtype = match hparams.f16_ { - 0 => ggml::Type::F32, - 1 => ggml::Type::F16, - 2 => ggml::Type::Q4_0, - 3 => ggml::Type::Q4_1, - invalid => return Err(LoadError::HyperparametersF16Invalid { ftype: invalid }), - }; + let wtype = hparams.element_type; let n_embd = hparams.n_embd; let n_layer = hparams.n_layer; @@ -857,7 +861,7 @@ impl Model { }; match model_type { - ModelVersion::GGMF | ModelVersion::Unversioned => { + ModelContainerType::GGMF | ModelContainerType::Unversioned => { let file_offset = reader.stream_position()?; drop(reader); load_weights_ggmf_or_unversioned( @@ -867,7 +871,7 @@ impl Model { &model, )? 
} - ModelVersion::GGJT => { + ModelContainerType::GGJT => { let mmap = unsafe { Mmap::map(&file)? }; let ptr = mmap.as_ptr(); model.mmap = Some(mmap); @@ -955,7 +959,7 @@ impl Model { n_head, n_layer, n_rot, - f16_: _, + element_type: _, } = self.hparams; // For the first run, we need to guess a maximum buffer size so we can measure diff --git a/llama-rs/src/loader.rs b/llama-rs/src/loader.rs index fc3b2c61..42aaad4d 100644 --- a/llama-rs/src/loader.rs +++ b/llama-rs/src/loader.rs @@ -1,4 +1,4 @@ -use crate::*; +use crate::{loader2::decode_element_type, *}; pub(crate) fn read_bytes(reader: &mut impl BufRead) -> Result<[u8; N], LoadError> { let mut bytes = [0u8; N]; @@ -37,7 +37,7 @@ pub(crate) fn read_string(reader: &mut impl BufRead, len: usize) -> Result Result { +pub(crate) fn has_data_left(reader: &mut impl BufRead) -> Result { reader.fill_buf().map(|b| !b.is_empty()) } @@ -173,7 +173,7 @@ pub(crate) fn load_weights_ggmf_or_unversioned( byte_size: total_size, tensor_count: n_tensors.try_into()?, }); - }; + } Ok(()) } @@ -264,20 +264,14 @@ fn load_tensor_header_ggmf<'a>( } fn tensor_type_size(ftype: i32, ne: [i64; 2]) -> Option { - + let ftype = decode_element_type(ftype)?; match ftype { - 0 => Some(ggml::type_size(ggml::Type::F32)), - 1 => Some(ggml::type_size(ggml::Type::F16)), - 2 => { - assert_eq!(ne[0] % 64, 0); - Some(ggml::type_size(ggml::Type::Q4_0)) - } - 3 => { + ElementType::Q4_0 | ElementType::Q4_1 => { assert_eq!(ne[0] % 64, 0); - Some(ggml::type_size(ggml::Type::Q4_1)) } - _ => None, + _ => {} } + Some(ggml::type_size(ftype)) } pub(crate) fn load_weights_ggjt( diff --git a/llama-rs/src/loader2.rs b/llama-rs/src/loader2.rs new file mode 100644 index 00000000..b7804216 --- /dev/null +++ b/llama-rs/src/loader2.rs @@ -0,0 +1,260 @@ +#![allow(missing_docs)] + +//! standalone model loader + +use std::{ + io::{BufRead, Seek, SeekFrom}, + ops::ControlFlow, +}; + +use crate::{loader::has_data_left, ElementType, ModelContainerType}; + +pub(crate) fn decode_element_type(ftype: i32) -> Option { + match ftype { + 0 => Some(ggml::Type::F32), + 1 => Some(ggml::Type::F16), + 2 => Some(ggml::Type::Q4_0), + 3 => Some(ggml::Type::Q4_1), + _ => None, + } +} + +pub(crate) fn encode_element_type(element_type: ElementType) -> Option { + match element_type { + ggml::Type::F32 => Some(0), + ggml::Type::F16 => Some(1), + ggml::Type::Q4_0 => Some(2), + ggml::Type::Q4_1 => Some(3), + _ => None, + } +} + +pub(crate) fn read_bytes( + reader: &mut impl BufRead, +) -> Result<[u8; N], std::io::Error> { + let mut bytes = [0u8; N]; + reader.read_exact(&mut bytes)?; + Ok(bytes) +} + +pub(crate) fn read_i32(reader: &mut impl BufRead) -> Result { + Ok(i32::from_le_bytes(read_bytes::<4>(reader)?)) +} + +pub(crate) fn read_u32(reader: &mut impl BufRead) -> Result { + Ok(u32::from_le_bytes(read_bytes::<4>(reader)?)) +} + +pub(crate) fn read_f32(reader: &mut impl BufRead) -> Result { + Ok(f32::from_le_bytes(read_bytes::<4>(reader)?)) +} + +pub(crate) fn read_bytes_with_len( + reader: &mut impl BufRead, + len: usize, +) -> Result, std::io::Error> { + let mut bytes = vec![0u8; len]; + reader.read_exact(&mut bytes)?; + Ok(bytes) +} + +/// The hyperparameters of the model. 
+#[derive(Debug, Clone)] +pub struct FixedHyperparameters { + pub n_vocab: usize, + pub n_embd: usize, + pub n_mult: usize, + pub n_head: usize, + pub n_layer: usize, + pub n_rot: usize, + pub tensor_element_type: ElementType, +} + +#[derive(Debug, thiserror::Error)] +pub enum LoadError { + #[error("invalid file magic number: {0}")] + InvalidMagic(u32), + + #[error("invalid ggml format: version={0}")] + InvalidFormatVersion(u32), + + #[error("{0}")] + Io(#[from] std::io::Error), + + #[error("{0}")] + FailedCast(#[from] std::num::TryFromIntError), + + /// return `ControlFlow::Break` from any of the `cb_*` function to trigger this error + #[error("user requested interrupt: {0}")] + UserInterrupted(T), + + #[error("unsupported tensor dtype/f16_: {0}")] + UnsupportedElementtype(i32), + + /// sanity check failed + #[error("invariant broken: {0}")] + InvariantBroken(String), +} + +#[derive(Debug, Clone)] +pub struct TensorInfo { + pub name: Vec, + pub n_dims: usize, + pub dims: [usize; 2], + pub n_elements: usize, + pub ftype: ElementType, +} + +#[allow(unused_variables)] +pub trait LoadHandler { + fn cb_container_type(&mut self, model_type: ModelContainerType) -> ControlFlow { + ControlFlow::Continue(()) + } + + fn cb_hyper_parameters(&mut self, hparams: FixedHyperparameters) -> ControlFlow { + ControlFlow::Continue(()) + } + + fn cb_vocab_token(&mut self, i: usize, token: Vec, score: f32) -> ControlFlow { + ControlFlow::Continue(()) + } + + fn tensor_buffer(&mut self, info: TensorInfo) -> ControlFlow; +} + +fn retchk(model_type: ControlFlow) -> Result> { + match model_type { + ControlFlow::Continue(x) => Ok(x), + ControlFlow::Break(y) => Err(LoadError::UserInterrupted(y)), + } +} + +pub fn load_model_from_reader( + mut reader: impl BufRead + Seek, + handler: &mut impl LoadHandler, +) -> Result<(), LoadError> { + // Verify magic + let container_type: ModelContainerType = match read_u32(&mut reader)? { + ggml::FILE_MAGIC_GGMF => ModelContainerType::GGMF, + ggml::FILE_MAGIC_GGJT => ModelContainerType::GGJT, + ggml::FILE_MAGIC_UNVERSIONED => ModelContainerType::Unversioned, + magic => return Err(LoadError::InvalidMagic(magic)), + }; + retchk(handler.cb_container_type(container_type))?; + + // Load format version + match container_type { + ModelContainerType::GGMF | ModelContainerType::GGJT => { + let _version: u32 = match read_u32(&mut reader)? { + ggml::FORMAT_VERSION => ggml::FORMAT_VERSION, + version => return Err(LoadError::InvalidFormatVersion(version)), + }; + } + ModelContainerType::Unversioned => {} + } + + // Load hyper params + // + // NOTE: Field order matters! Data is laid out in the file exactly + // in this order. + let hparams = FixedHyperparameters { + n_vocab: read_i32(&mut reader)?.try_into()?, + n_embd: read_i32(&mut reader)?.try_into()?, + n_mult: read_i32(&mut reader)?.try_into()?, + n_head: read_i32(&mut reader)?.try_into()?, + n_layer: read_i32(&mut reader)?.try_into()?, + n_rot: read_i32(&mut reader)?.try_into()?, + tensor_element_type: decode_element_type_res(read_i32(&mut reader)?)?, + }; + let n_vocab = hparams.n_vocab; + retchk(handler.cb_hyper_parameters(hparams))?; + + // Load vocabulary + for i in 0..n_vocab { + let len = read_u32(&mut reader)?.try_into()?; + let token = read_bytes_with_len(&mut reader, len)?; + let token_score = match container_type { + ModelContainerType::GGMF | ModelContainerType::GGJT => read_f32(&mut reader)?, + ModelContainerType::Unversioned => { + // Legacy model, set empty score + 0. 
+ } + }; + retchk(handler.cb_vocab_token(i, token, token_score))?; + } + + // Load tensor data + match container_type { + ModelContainerType::GGMF | ModelContainerType::Unversioned => { + let _file_offset = reader.stream_position()?; + drop(reader); + todo!() + } + ModelContainerType::GGJT => load_weights_ggjt(&mut reader, handler), + } +} + +fn decode_element_type_res(ftype: i32) -> Result> { + match decode_element_type(ftype) { + Some(x) => Ok(x), + None => Err(LoadError::UnsupportedElementtype(ftype)), + } +} + +fn load_weights_ggjt( + reader: &mut (impl BufRead + Seek), + handler: &mut impl LoadHandler, +) -> Result<(), LoadError> { + while has_data_left(reader)? { + // load tensor header + let n_dims: usize = read_i32(reader)?.try_into()?; + let name_len = read_i32(reader)?; + let ftype = decode_element_type_res(read_i32(reader)?)?; + + let mut n_elements: usize = 1; + let mut dims = [1usize, 1]; + let ne_len = dims.len(); + if !(n_dims <= ne_len) { + return Err(LoadError::InvariantBroken(format!("{n_dims} <= {ne_len}"))); + } + #[allow(clippy::needless_range_loop)] + for i in 0..n_dims { + let dim: usize = read_i32(reader)?.try_into()?; + dims[i] = dim; + n_elements *= dim; + } + + // load tensor name + let name = read_bytes_with_len(reader, name_len.try_into()?)?; + + // sanity check + match ftype { + ElementType::Q4_0 | ElementType::Q4_1 => { + if !(dims[0] % 64 == 0) { + return Err(LoadError::InvariantBroken(format!("{dims:?}[0] % 64 == 0"))); + } + } + _ => {} + } + + let tensor_info = TensorInfo { + name, dims, n_dims, n_elements, ftype, + }; + + // load tensor weights + let offset_curr = reader.stream_position()?; + let offset_aligned: u64 = (offset_curr + 31) & !31; + reader.seek(SeekFrom::Start(offset_aligned))?; + + let type_size = ggml::type_size(ftype); + let buf = retchk(handler.tensor_buffer(tensor_info))?; + let buf_len = buf.len(); + if !(buf_len == type_size * n_elements) { + return Err(LoadError::InvariantBroken(format!("{buf_len} == {type_size} * {n_elements}"))); + } + reader.read_exact(buf)?; + } + + Ok(()) +} + From 83905936691b2637e381c06049dbb17f8c498966 Mon Sep 17 00:00:00 2001 From: Locria Cyber <74560659+iacore@users.noreply.github.com> Date: Sat, 8 Apr 2023 12:07:36 +0000 Subject: [PATCH 13/42] Move loader to standalone crate llama-loader --- Cargo.lock | 9 ++ Cargo.toml | 1 + ggml/src/lib.rs | 2 +- llama-loader/Cargo.toml | 10 ++ .../src/loader2.rs => llama-loader/src/lib.rs | 100 ++++++++---------- llama-loader/src/util.rs | 33 ++++++ llama-rs/Cargo.toml | 1 + llama-rs/src/convert.rs | 3 +- llama-rs/src/lib.rs | 41 +++---- llama-rs/src/loader.rs | 9 +- 10 files changed, 125 insertions(+), 84 deletions(-) create mode 100644 llama-loader/Cargo.toml rename llama-rs/src/loader2.rs => llama-loader/src/lib.rs (73%) create mode 100644 llama-loader/src/util.rs diff --git a/Cargo.lock b/Cargo.lock index 38bd0498..6a31264a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -533,6 +533,14 @@ dependencies = [ "zstd", ] +[[package]] +name = "llama-loader" +version = "0.1.0" +dependencies = [ + "ggml", + "thiserror", +] + [[package]] name = "llama-rs" version = "0.1.0" @@ -540,6 +548,7 @@ dependencies = [ "bincode", "bytemuck", "ggml", + "llama-loader", "memmap2", "partial_sort", "protobuf", diff --git a/Cargo.toml b/Cargo.toml index 8ea220d8..4c383de8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,6 +2,7 @@ members = [ "ggml-sys", "ggml", + "llama-loader", "llama-rs", "llama-cli", "generate-ggml-bindings" diff --git a/ggml/src/lib.rs b/ggml/src/lib.rs index 
915d5692..5aae51fa 100644 --- a/ggml/src/lib.rs +++ b/ggml/src/lib.rs @@ -440,7 +440,7 @@ impl Tensor { // /// Set the tensor's data pointer (useful for mmap-ed data) // /// // /// # Safety - // /// + // /// // /// The memory region from `data_ptr` to `data_ptr.offset(tensor.nbytes())` will be read from. // pub unsafe fn set_data(&self, data_ptr: *mut c_void) { // self.with_alive_ctx(|| { diff --git a/llama-loader/Cargo.toml b/llama-loader/Cargo.toml new file mode 100644 index 00000000..cfc8d48b --- /dev/null +++ b/llama-loader/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "llama-loader" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +ggml = { path = "../ggml" } +thiserror = "*" diff --git a/llama-rs/src/loader2.rs b/llama-loader/src/lib.rs similarity index 73% rename from llama-rs/src/loader2.rs rename to llama-loader/src/lib.rs index b7804216..35c1e902 100644 --- a/llama-rs/src/loader2.rs +++ b/llama-loader/src/lib.rs @@ -1,15 +1,31 @@ -#![allow(missing_docs)] +//! standalone model loader +//! +//! Only the hyperparameter is llama-specific. Everything else can be reused for other LLM. +#![allow(clippy::nonminimal_bool)] -//! standalone model loader +pub mod util; use std::{ io::{BufRead, Seek, SeekFrom}, ops::ControlFlow, }; +use util::*; + +pub type ElementType = ggml::Type; + +/// file type containing the model +#[derive(Debug, PartialEq, Clone, Copy)] +#[allow(clippy::upper_case_acronyms)] +pub enum ContainerType { + /// legacy format, oldest ggml tensor file format + GGML, + /// also legacy format, newer than GGML, older than GGJT + GGMF, + /// mmap-able format + GGJT, +} -use crate::{loader::has_data_left, ElementType, ModelContainerType}; - -pub(crate) fn decode_element_type(ftype: i32) -> Option { +pub fn decode_element_type(ftype: i32) -> Option { match ftype { 0 => Some(ggml::Type::F32), 1 => Some(ggml::Type::F16), @@ -19,7 +35,7 @@ pub(crate) fn decode_element_type(ftype: i32) -> Option { } } -pub(crate) fn encode_element_type(element_type: ElementType) -> Option { +pub fn encode_element_type(element_type: ElementType) -> Option { match element_type { ggml::Type::F32 => Some(0), ggml::Type::F16 => Some(1), @@ -29,38 +45,9 @@ pub(crate) fn encode_element_type(element_type: ElementType) -> Option { } } -pub(crate) fn read_bytes( - reader: &mut impl BufRead, -) -> Result<[u8; N], std::io::Error> { - let mut bytes = [0u8; N]; - reader.read_exact(&mut bytes)?; - Ok(bytes) -} - -pub(crate) fn read_i32(reader: &mut impl BufRead) -> Result { - Ok(i32::from_le_bytes(read_bytes::<4>(reader)?)) -} - -pub(crate) fn read_u32(reader: &mut impl BufRead) -> Result { - Ok(u32::from_le_bytes(read_bytes::<4>(reader)?)) -} - -pub(crate) fn read_f32(reader: &mut impl BufRead) -> Result { - Ok(f32::from_le_bytes(read_bytes::<4>(reader)?)) -} - -pub(crate) fn read_bytes_with_len( - reader: &mut impl BufRead, - len: usize, -) -> Result, std::io::Error> { - let mut bytes = vec![0u8; len]; - reader.read_exact(&mut bytes)?; - Ok(bytes) -} - /// The hyperparameters of the model. 
#[derive(Debug, Clone)] -pub struct FixedHyperparameters { +pub struct LlamaHyperparameters { pub n_vocab: usize, pub n_embd: usize, pub n_mult: usize, @@ -90,7 +77,7 @@ pub enum LoadError { #[error("unsupported tensor dtype/f16_: {0}")] UnsupportedElementtype(i32), - + /// sanity check failed #[error("invariant broken: {0}")] InvariantBroken(String), @@ -107,11 +94,11 @@ pub struct TensorInfo { #[allow(unused_variables)] pub trait LoadHandler { - fn cb_container_type(&mut self, model_type: ModelContainerType) -> ControlFlow { + fn cb_container_type(&mut self, model_type: ContainerType) -> ControlFlow { ControlFlow::Continue(()) } - fn cb_hyper_parameters(&mut self, hparams: FixedHyperparameters) -> ControlFlow { + fn cb_hyper_parameters(&mut self, hparams: LlamaHyperparameters) -> ControlFlow { ControlFlow::Continue(()) } @@ -134,30 +121,30 @@ pub fn load_model_from_reader( handler: &mut impl LoadHandler, ) -> Result<(), LoadError> { // Verify magic - let container_type: ModelContainerType = match read_u32(&mut reader)? { - ggml::FILE_MAGIC_GGMF => ModelContainerType::GGMF, - ggml::FILE_MAGIC_GGJT => ModelContainerType::GGJT, - ggml::FILE_MAGIC_UNVERSIONED => ModelContainerType::Unversioned, + let container_type: ContainerType = match read_u32(&mut reader)? { + ggml::FILE_MAGIC_GGMF => ContainerType::GGMF, + ggml::FILE_MAGIC_GGJT => ContainerType::GGJT, + ggml::FILE_MAGIC_UNVERSIONED => ContainerType::GGML, magic => return Err(LoadError::InvalidMagic(magic)), }; retchk(handler.cb_container_type(container_type))?; // Load format version match container_type { - ModelContainerType::GGMF | ModelContainerType::GGJT => { + ContainerType::GGMF | ContainerType::GGJT => { let _version: u32 = match read_u32(&mut reader)? { ggml::FORMAT_VERSION => ggml::FORMAT_VERSION, version => return Err(LoadError::InvalidFormatVersion(version)), }; } - ModelContainerType::Unversioned => {} + ContainerType::GGML => {} } // Load hyper params // // NOTE: Field order matters! Data is laid out in the file exactly // in this order. - let hparams = FixedHyperparameters { + let hparams = LlamaHyperparameters { n_vocab: read_i32(&mut reader)?.try_into()?, n_embd: read_i32(&mut reader)?.try_into()?, n_mult: read_i32(&mut reader)?.try_into()?, @@ -174,8 +161,8 @@ pub fn load_model_from_reader( let len = read_u32(&mut reader)?.try_into()?; let token = read_bytes_with_len(&mut reader, len)?; let token_score = match container_type { - ModelContainerType::GGMF | ModelContainerType::GGJT => read_f32(&mut reader)?, - ModelContainerType::Unversioned => { + ContainerType::GGMF | ContainerType::GGJT => read_f32(&mut reader)?, + ContainerType::GGML => { // Legacy model, set empty score 0. 
} @@ -185,12 +172,12 @@ pub fn load_model_from_reader( // Load tensor data match container_type { - ModelContainerType::GGMF | ModelContainerType::Unversioned => { + ContainerType::GGMF | ContainerType::GGML => { let _file_offset = reader.stream_position()?; drop(reader); todo!() } - ModelContainerType::GGJT => load_weights_ggjt(&mut reader, handler), + ContainerType::GGJT => load_weights_ggjt(&mut reader, handler), } } @@ -238,23 +225,28 @@ fn load_weights_ggjt( } let tensor_info = TensorInfo { - name, dims, n_dims, n_elements, ftype, + name, + dims, + n_dims, + n_elements, + ftype, }; // load tensor weights let offset_curr = reader.stream_position()?; let offset_aligned: u64 = (offset_curr + 31) & !31; reader.seek(SeekFrom::Start(offset_aligned))?; - + let type_size = ggml::type_size(ftype); let buf = retchk(handler.tensor_buffer(tensor_info))?; let buf_len = buf.len(); if !(buf_len == type_size * n_elements) { - return Err(LoadError::InvariantBroken(format!("{buf_len} == {type_size} * {n_elements}"))); + return Err(LoadError::InvariantBroken(format!( + "{buf_len} == {type_size} * {n_elements}" + ))); } reader.read_exact(buf)?; } Ok(()) } - diff --git a/llama-loader/src/util.rs b/llama-loader/src/util.rs new file mode 100644 index 00000000..06e5312f --- /dev/null +++ b/llama-loader/src/util.rs @@ -0,0 +1,33 @@ +use std::io::BufRead; + +pub fn read_bytes(reader: &mut impl BufRead) -> Result<[u8; N], std::io::Error> { + let mut bytes = [0u8; N]; + reader.read_exact(&mut bytes)?; + Ok(bytes) +} + +pub fn read_i32(reader: &mut impl BufRead) -> Result { + Ok(i32::from_le_bytes(read_bytes::<4>(reader)?)) +} + +pub fn read_u32(reader: &mut impl BufRead) -> Result { + Ok(u32::from_le_bytes(read_bytes::<4>(reader)?)) +} + +pub fn read_f32(reader: &mut impl BufRead) -> Result { + Ok(f32::from_le_bytes(read_bytes::<4>(reader)?)) +} + +pub fn read_bytes_with_len( + reader: &mut impl BufRead, + len: usize, +) -> Result, std::io::Error> { + let mut bytes = vec![0u8; len]; + reader.read_exact(&mut bytes)?; + Ok(bytes) +} + +// NOTE: Implementation from #![feature(buf_read_has_data_left)] +pub fn has_data_left(reader: &mut impl BufRead) -> Result { + reader.fill_buf().map(|b| !b.is_empty()) +} diff --git a/llama-rs/Cargo.toml b/llama-rs/Cargo.toml index 302a1389..d6d42087 100644 --- a/llama-rs/Cargo.toml +++ b/llama-rs/Cargo.toml @@ -8,6 +8,7 @@ rust-version = "1.65" [dependencies] ggml = { path = "../ggml" } +llama-loader = { path = "../llama-loader" } bytemuck = "1.13.1" partial_sort = "0.2.0" diff --git a/llama-rs/src/convert.rs b/llama-rs/src/convert.rs index 450150b5..3a57b168 100644 --- a/llama-rs/src/convert.rs +++ b/llama-rs/src/convert.rs @@ -16,7 +16,8 @@ use std::{ vec, }; -use crate::{util, Hyperparameters, Vocabulary, loader2::encode_element_type}; +use crate::{util, Hyperparameters, Vocabulary}; +use llama_loader::encode_element_type; /// Converts a `pth` file to a `ggml` file. 
pub fn convert_pth_to_ggml(model_directory: &Path, element_type: ggml::Type) { diff --git a/llama-rs/src/lib.rs b/llama-rs/src/lib.rs index 0965095f..1ff7fa62 100644 --- a/llama-rs/src/lib.rs +++ b/llama-rs/src/lib.rs @@ -4,14 +4,13 @@ #[cfg(feature = "convert")] pub mod convert; mod loader; -pub mod loader2; mod util; use core::slice; use std::{ collections::HashMap, fmt::Display, - io::{BufRead, Read, Seek, SeekFrom}, + io::Seek, path::{Path, PathBuf}, str::FromStr, time, @@ -25,7 +24,7 @@ use thiserror::Error; #[cfg(feature = "mmap")] use memmap2::Mmap; -use crate::loader2::decode_element_type; +use llama_loader::{decode_element_type, ContainerType}; /// dummy struct #[cfg(not(feature = "mmap"))] @@ -76,18 +75,6 @@ struct Layer { w3: ggml::Tensor, } -/// file type containing the model -#[derive(Debug, PartialEq, Clone, Copy)] -#[allow(clippy::upper_case_acronyms)] -pub enum ModelContainerType { - /// older than `GGJT` - GGMF, - /// mmap-able format - GGJT, - /// oldest ggml tensor file format - Unversioned, -} - /// The weights for the LLaMA model. All the mutable state is split into a /// separate struct `InferenceSession`. pub struct Model { @@ -104,7 +91,7 @@ pub struct Model { mmap: Option, - _version: ModelContainerType, + _version: ContainerType, // Must be kept alive for the model _context: ggml::Context, @@ -665,10 +652,10 @@ impl Model { } // Verify magic - let model_type: ModelContainerType = match read_u32(&mut reader)? { - ggml::FILE_MAGIC_GGMF => ModelContainerType::GGMF, - ggml::FILE_MAGIC_GGJT => ModelContainerType::GGJT, - ggml::FILE_MAGIC_UNVERSIONED => ModelContainerType::Unversioned, + let model_type: ContainerType = match read_u32(&mut reader)? { + ggml::FILE_MAGIC_GGMF => ContainerType::GGMF, + ggml::FILE_MAGIC_GGJT => ContainerType::GGJT, + ggml::FILE_MAGIC_UNVERSIONED => ContainerType::GGML, _ => { return Err(LoadError::InvalidMagic { path: main_path.to_owned(), @@ -678,13 +665,13 @@ impl Model { // Load format version match model_type { - ModelContainerType::GGMF | ModelContainerType::GGJT => { + ContainerType::GGMF | ContainerType::GGJT => { let _version: u32 = match read_u32(&mut reader)? { ggml::FORMAT_VERSION => ggml::FORMAT_VERSION, version => return Err(LoadError::InvalidFormatVersion { value: version }), }; } - ModelContainerType::Unversioned => {} + ContainerType::GGML => {} } // ================= @@ -723,18 +710,18 @@ impl Model { for i in 0..hparams.n_vocab { let len = read_i32(&mut reader)?; - let token = read_bytes_with_len(&mut reader, len as usize)?; + let token = read_bytes_with_len(&mut reader, len)?; max_token_length = max_token_length.max(token.len()); id_to_token.push(token.clone()); token_to_id.insert(token, TokenId::try_from(i)?); // Token score, currently unused match model_type { - ModelContainerType::GGMF | ModelContainerType::GGJT => { + ContainerType::GGMF | ContainerType::GGJT => { let score = read_f32(&mut reader)?; id_to_token_score.push(score); } - ModelContainerType::Unversioned => { + ContainerType::GGML => { // Legacy model, set empty score id_to_token_score.push(0.); } @@ -861,7 +848,7 @@ impl Model { }; match model_type { - ModelContainerType::GGMF | ModelContainerType::Unversioned => { + ContainerType::GGMF | ContainerType::GGML => { let file_offset = reader.stream_position()?; drop(reader); load_weights_ggmf_or_unversioned( @@ -871,7 +858,7 @@ impl Model { &model, )? } - ModelContainerType::GGJT => { + ContainerType::GGJT => { let mmap = unsafe { Mmap::map(&file)? 
}; let ptr = mmap.as_ptr(); model.mmap = Some(mmap); diff --git a/llama-rs/src/loader.rs b/llama-rs/src/loader.rs index 42aaad4d..ad2b9ab0 100644 --- a/llama-rs/src/loader.rs +++ b/llama-rs/src/loader.rs @@ -1,4 +1,11 @@ -use crate::{loader2::decode_element_type, *}; +use std::{ + io::{BufRead, Read, Seek, SeekFrom}, + path::Path, +}; + +use crate::ElementType; +use crate::{util, LoadError, LoadProgress, Model}; +use llama_loader::decode_element_type; pub(crate) fn read_bytes(reader: &mut impl BufRead) -> Result<[u8; N], LoadError> { let mut bytes = [0u8; N]; From 15fe19b854c590cf0d42747ea428eccba4426df0 Mon Sep 17 00:00:00 2001 From: Locria Cyber <74560659+iacore@users.noreply.github.com> Date: Sat, 8 Apr 2023 12:29:24 +0000 Subject: [PATCH 14/42] [llama-loader] Support non-copy loader --- llama-loader/src/lib.rs | 52 ++++++++++++++++++++++++++--------------- 1 file changed, 33 insertions(+), 19 deletions(-) diff --git a/llama-loader/src/lib.rs b/llama-loader/src/lib.rs index 35c1e902..8cf4ed8d 100644 --- a/llama-loader/src/lib.rs +++ b/llama-loader/src/lib.rs @@ -90,23 +90,31 @@ pub struct TensorInfo { pub dims: [usize; 2], pub n_elements: usize, pub ftype: ElementType, + /// start of tensor - start of file + pub start_offset: u64, } #[allow(unused_variables)] pub trait LoadHandler { - fn cb_container_type(&mut self, model_type: ContainerType) -> ControlFlow { + fn got_container_type(&mut self, model_type: ContainerType) -> ControlFlow { ControlFlow::Continue(()) } - fn cb_hyper_parameters(&mut self, hparams: LlamaHyperparameters) -> ControlFlow { + fn got_hyper_parameters(&mut self, hparams: LlamaHyperparameters) -> ControlFlow { ControlFlow::Continue(()) } - fn cb_vocab_token(&mut self, i: usize, token: Vec, score: f32) -> ControlFlow { + fn got_vocab_token(&mut self, i: usize, token: Vec, score: f32) -> ControlFlow { ControlFlow::Continue(()) } - fn tensor_buffer(&mut self, info: TensorInfo) -> ControlFlow; + /// # Returns + /// + /// `None` to skip copying + /// `Some(buf)` to provide a buffer for copying weights into + fn get_tensor_buffer(&mut self, info: TensorInfo) -> ControlFlow> { + ControlFlow::Continue(None) + } } fn retchk(model_type: ControlFlow) -> Result> { @@ -127,7 +135,7 @@ pub fn load_model_from_reader( ggml::FILE_MAGIC_UNVERSIONED => ContainerType::GGML, magic => return Err(LoadError::InvalidMagic(magic)), }; - retchk(handler.cb_container_type(container_type))?; + retchk(handler.got_container_type(container_type))?; // Load format version match container_type { @@ -154,7 +162,7 @@ pub fn load_model_from_reader( tensor_element_type: decode_element_type_res(read_i32(&mut reader)?)?, }; let n_vocab = hparams.n_vocab; - retchk(handler.cb_hyper_parameters(hparams))?; + retchk(handler.got_hyper_parameters(hparams))?; // Load vocabulary for i in 0..n_vocab { @@ -167,7 +175,7 @@ pub fn load_model_from_reader( 0. 
} }; - retchk(handler.cb_vocab_token(i, token, token_score))?; + retchk(handler.got_vocab_token(i, token, token_score))?; } // Load tensor data @@ -224,28 +232,34 @@ fn load_weights_ggjt( _ => {} } + // load tensor weights + let offset_curr = reader.stream_position()?; + let offset_aligned: u64 = (offset_curr + 31) & !31; + let tensor_info = TensorInfo { name, dims, n_dims, n_elements, ftype, + start_offset: offset_aligned }; - // load tensor weights - let offset_curr = reader.stream_position()?; - let offset_aligned: u64 = (offset_curr + 31) & !31; - reader.seek(SeekFrom::Start(offset_aligned))?; - + let type_size = ggml::type_size(ftype); - let buf = retchk(handler.tensor_buffer(tensor_info))?; - let buf_len = buf.len(); - if !(buf_len == type_size * n_elements) { - return Err(LoadError::InvariantBroken(format!( - "{buf_len} == {type_size} * {n_elements}" - ))); + if let Some(buf) = retchk(handler.get_tensor_buffer(tensor_info))? { + reader.seek(SeekFrom::Start(offset_aligned))?; + let buf_len = buf.len(); + if !(buf_len == type_size * n_elements) { + return Err(LoadError::InvariantBroken(format!( + "{buf_len} == {type_size} * {n_elements}" + ))); + } + reader.read_exact(buf)?; + } else { + // skip if no buffer is given + reader.seek(SeekFrom::Start(offset_aligned + (type_size * n_elements) as u64))?; } - reader.read_exact(buf)?; } Ok(()) From 2e9311dec430c07e1233e2dd860efd8cb0575269 Mon Sep 17 00:00:00 2001 From: Locria Cyber <74560659+iacore@users.noreply.github.com> Date: Sat, 8 Apr 2023 13:46:19 +0000 Subject: [PATCH 15/42] Use functions from the new crate --- llama-loader/src/lib.rs | 9 +++++---- llama-rs/src/lib.rs | 45 ++--------------------------------------- llama-rs/src/loader.rs | 29 +------------------------- 3 files changed, 8 insertions(+), 75 deletions(-) diff --git a/llama-loader/src/lib.rs b/llama-loader/src/lib.rs index 8cf4ed8d..3a0ab8f4 100644 --- a/llama-loader/src/lib.rs +++ b/llama-loader/src/lib.rs @@ -109,7 +109,7 @@ pub trait LoadHandler { } /// # Returns - /// + /// /// `None` to skip copying /// `Some(buf)` to provide a buffer for copying weights into fn get_tensor_buffer(&mut self, info: TensorInfo) -> ControlFlow> { @@ -242,10 +242,9 @@ fn load_weights_ggjt( n_dims, n_elements, ftype, - start_offset: offset_aligned + start_offset: offset_aligned, }; - let type_size = ggml::type_size(ftype); if let Some(buf) = retchk(handler.get_tensor_buffer(tensor_info))? 
{ reader.seek(SeekFrom::Start(offset_aligned))?; @@ -258,7 +257,9 @@ fn load_weights_ggjt( reader.read_exact(buf)?; } else { // skip if no buffer is given - reader.seek(SeekFrom::Start(offset_aligned + (type_size * n_elements) as u64))?; + reader.seek(SeekFrom::Start( + offset_aligned + (type_size * n_elements) as u64, + ))?; } } diff --git a/llama-rs/src/lib.rs b/llama-rs/src/lib.rs index 1ff7fa62..df5b2da0 100644 --- a/llama-rs/src/lib.rs +++ b/llama-rs/src/lib.rs @@ -24,6 +24,7 @@ use thiserror::Error; #[cfg(feature = "mmap")] use memmap2::Mmap; +use llama_loader::util::*; use llama_loader::{decode_element_type, ContainerType}; /// dummy struct @@ -609,48 +610,6 @@ impl Model { })?; let mut reader = BufReader::new(&file); - fn read_bytes(reader: &mut impl BufRead) -> Result<[u8; N], LoadError> { - let mut bytes = [0u8; N]; - reader - .read_exact(&mut bytes) - .map_err(|e| LoadError::ReadExactFailed { - source: e, - bytes: N, - })?; - Ok(bytes) - } - - fn read_bytes_with_len( - reader: &mut impl BufRead, - len: usize, - ) -> Result, LoadError> { - let mut bytes = vec![0u8; len]; - reader - .read_exact(&mut bytes) - .map_err(|e| LoadError::ReadExactFailed { - source: e, - bytes: len, - })?; - Ok(bytes) - } - - fn read_i32(reader: &mut impl BufRead) -> Result { - Ok(i32::from_le_bytes(read_bytes::<4>(reader)?)) - } - - fn read_u32(reader: &mut impl BufRead) -> Result { - Ok(u32::from_le_bytes(read_bytes::<4>(reader)?)) - } - - fn read_f32(reader: &mut impl BufRead) -> Result { - Ok(f32::from_le_bytes(read_bytes::<4>(reader)?)) - } - - /// Helper function. Reads a string from the buffer and returns it. - fn read_string(reader: &mut BufReader, len: usize) -> Result { - Ok(String::from_utf8(read_bytes_with_len(reader, len)?)?) - } - // Verify magic let model_type: ContainerType = match read_u32(&mut reader)? { ggml::FILE_MAGIC_GGMF => ContainerType::GGMF, @@ -710,7 +669,7 @@ impl Model { for i in 0..hparams.n_vocab { let len = read_i32(&mut reader)?; - let token = read_bytes_with_len(&mut reader, len)?; + let token = read_bytes_with_len(&mut reader, len.try_into()?)?; max_token_length = max_token_length.max(token.len()); id_to_token.push(token.clone()); token_to_id.insert(token, TokenId::try_from(i)?); diff --git a/llama-rs/src/loader.rs b/llama-rs/src/loader.rs index ad2b9ab0..b76be7fe 100644 --- a/llama-rs/src/loader.rs +++ b/llama-rs/src/loader.rs @@ -6,29 +6,7 @@ use std::{ use crate::ElementType; use crate::{util, LoadError, LoadProgress, Model}; use llama_loader::decode_element_type; - -pub(crate) fn read_bytes(reader: &mut impl BufRead) -> Result<[u8; N], LoadError> { - let mut bytes = [0u8; N]; - reader - .read_exact(&mut bytes) - .map_err(|e| LoadError::ReadExactFailed { - source: e, - bytes: N, - })?; - Ok(bytes) -} - -pub(crate) fn read_i32(reader: &mut impl BufRead) -> Result { - Ok(i32::from_le_bytes(read_bytes::<4>(reader)?)) -} - -pub(crate) fn read_u32(reader: &mut impl BufRead) -> Result { - Ok(u32::from_le_bytes(read_bytes::<4>(reader)?)) -} - -pub(crate) fn read_f32(reader: &mut impl BufRead) -> Result { - Ok(f32::from_le_bytes(read_bytes::<4>(reader)?)) -} +use llama_loader::util::*; /// Helper function. Reads a string from the buffer and returns it. 
pub(crate) fn read_string(reader: &mut impl BufRead, len: usize) -> Result { @@ -43,11 +21,6 @@ pub(crate) fn read_string(reader: &mut impl BufRead, len: usize) -> Result Result { - reader.fill_buf().map(|b| !b.is_empty()) -} - pub(crate) fn load_weights_ggmf_or_unversioned( file_offset: u64, main_path: &Path, From 34429e02d3ce766eaca779d7d08c497eeed03a1a Mon Sep 17 00:00:00 2001 From: Philpax Date: Thu, 13 Apr 2023 12:44:18 +0200 Subject: [PATCH 16/42] refactor(llama): pass mut tensors down --- ggml/src/lib.rs | 6 +++--- llama-rs/src/loader.rs | 31 ++++++++++++++++++------------- llama-rs/src/model.rs | 6 +++--- 3 files changed, 24 insertions(+), 19 deletions(-) diff --git a/ggml/src/lib.rs b/ggml/src/lib.rs index 5e9d7185..303d8969 100644 --- a/ggml/src/lib.rs +++ b/ggml/src/lib.rs @@ -477,7 +477,7 @@ impl Tensor { /// # Safety /// /// Only `std::slice::from_raw_parts_mut(tensor.data(), tensor.nbytes())` is safe to mutate. - pub unsafe fn data(&self) -> *mut c_void { + pub unsafe fn data(&mut self) -> *mut c_void { self.with_alive_ctx(|| { // SAFETY: The with_alive_call guarantees the context is alive unsafe { *self.ptr.as_ptr() }.data @@ -529,12 +529,12 @@ impl Tensor { /// # Safety /// /// This tensor must not be written to or read by from any other code. - pub unsafe fn write_data(&self, src: &[u8]) { + pub unsafe fn write_data(&mut self, src: &[u8]) { std::ptr::copy_nonoverlapping(src.as_ptr(), self.data() as *mut u8, src.len()) } /// Zeroes out this tensor. - pub fn zero_data(&self) { + pub fn zero_data(&mut self) { unsafe { std::ptr::write_bytes(self.data() as *mut u8, 0, self.nbytes()) } } diff --git a/llama-rs/src/loader.rs b/llama-rs/src/loader.rs index 41073658..32991761 100644 --- a/llama-rs/src/loader.rs +++ b/llama-rs/src/loader.rs @@ -281,7 +281,6 @@ pub(crate) fn load( let context = ggml::Context::init(ctx_size); let mut model = Model::new(context, hparams, vocabulary, n_ff, wtype, model_type); - match model_type { ContainerType::GGMF | ContainerType::GGML => { let file_offset = reader.stream_position()?; @@ -290,14 +289,20 @@ pub(crate) fn load( file_offset, main_path, load_progress_callback, - &model, + model.tensors_mut(), )? } ContainerType::GGJT => { let mmap = unsafe { Mmap::map(&file)? 
}; let ptr = mmap.as_ptr(); model.mmap = Some(mmap); - load_weights_ggjt(&mut reader, ptr, main_path, load_progress_callback, &model)?; + load_weights_ggjt( + &mut reader, + ptr, + main_path, + load_progress_callback, + model.tensors_mut(), + )?; } } @@ -321,7 +326,7 @@ fn load_weights_ggmf_or_unversioned( file_offset: u64, main_path: &Path, mut load_progress_callback: impl FnMut(LoadProgress), - model: &Model, + tensors: &mut HashMap, ) -> Result<(), LoadError> { use std::{fs::File, io::BufReader}; @@ -359,7 +364,7 @@ fn load_weights_ggmf_or_unversioned( n_dims, &mut part_reader, length, - model, + tensors, &part_path, n_parts, ftype, @@ -440,7 +445,7 @@ fn load_weights_ggmf_or_unversioned( load_progress_callback(LoadProgress::PartTensorLoaded { file: &part_path, current_tensor: n_tensors.try_into()?, - tensor_count: model.tensors().len(), + tensor_count: tensors.len(), }); } @@ -458,11 +463,11 @@ fn load_tensor_header_ggmf<'a>( n_dims: usize, reader: &mut impl BufRead, length: i32, - model: &'a Model, + tensors: &'a mut HashMap, path: &Path, n_parts: usize, ftype: i32, -) -> Result<(usize, [i64; 2], String, &'a ggml::Tensor, i32, usize), LoadError> { +) -> Result<(usize, [i64; 2], String, &'a mut ggml::Tensor, i32, usize), LoadError> { let mut nelements = 1; let mut ne = [1i64, 1i64]; assert!(n_dims <= ne.len()); @@ -472,7 +477,7 @@ fn load_tensor_header_ggmf<'a>( nelements *= usize::try_from(ne[i])?; } let tensor_name = read_string(reader, length as usize)?; - let Some(tensor) = model.tensors().get(&tensor_name) + let Some(tensor) = tensors.get_mut(&tensor_name) else { return Err(LoadError::UnknownTensor { tensor_name, path: path.to_owned() }); }; @@ -555,7 +560,7 @@ fn load_weights_ggjt( mmap_base: *const u8, path: &Path, mut load_progress_callback: impl FnMut(LoadProgress), - model: &Model, + tensors: &mut HashMap, ) -> Result<(), LoadError> // where R: std::io::Read { @@ -586,7 +591,7 @@ fn load_weights_ggjt( nelements *= dim; } let tensor_name = read_string(reader, length as usize)?; - let Some(tensor) = model.tensors().get(&tensor_name) + let Some(tensor) = tensors.get_mut(&tensor_name) else { return Err(LoadError::UnknownTensor { tensor_name, path: path.to_owned() }); }; @@ -623,7 +628,7 @@ fn load_weights_ggjt( load_progress_callback(LoadProgress::PartTensorLoaded { file: path, current_tensor: loop_i, - tensor_count: model.tensors().len(), + tensor_count: tensors.len(), }); loop_i += 1; @@ -658,7 +663,7 @@ fn load_tensor_ggjt( fn load_tensor_ggjt<'a>( reader: &mut (impl BufRead + Seek), mmap_base: *const u8, - tensor: &'a ggml::Tensor, + tensor: &'a mut ggml::Tensor, ) -> Result<(), LoadError> { _ = mmap_base; let offset_curr = reader.stream_position()?; diff --git a/llama-rs/src/model.rs b/llama-rs/src/model.rs index d5e11fed..95451b9a 100644 --- a/llama-rs/src/model.rs +++ b/llama-rs/src/model.rs @@ -187,7 +187,7 @@ impl Model { let mut gf = ggml::ComputationGraph::new(n_threads); - let embd = ctx0.new_tensor_1d(ggml::Type::I32, n); + let mut embd = ctx0.new_tensor_1d(ggml::Type::I32, n); unsafe { embd.write_data(bytemuck::cast_slice(input_tokens)) }; let mut input_layer = ctx0.op_get_rows(&self.tok_embeddings, &embd); @@ -432,8 +432,8 @@ impl Model { &self.vocabulary } - pub(crate) fn tensors(&self) -> &HashMap { - &self.tensors + pub(crate) fn tensors_mut(&mut self) -> &mut HashMap { + &mut self.tensors } } From 38e7d58242351eee7471a18f83c93537cdc7160e Mon Sep 17 00:00:00 2001 From: Locria Cyber <74560659+iacore@users.noreply.github.com> Date: Fri, 14 Apr 2023 10:53:50 +0000 
Subject: [PATCH 17/42] feat/loader Make hparams configurable --- llama-loader/src/lib.rs | 63 +++++++++++++++++++++++++---------------- 1 file changed, 39 insertions(+), 24 deletions(-) diff --git a/llama-loader/src/lib.rs b/llama-loader/src/lib.rs index 3a0ab8f4..86a7a00d 100644 --- a/llama-loader/src/lib.rs +++ b/llama-loader/src/lib.rs @@ -94,20 +94,41 @@ pub struct TensorInfo { pub start_offset: u64, } +/// Info in hyperparameter used for later loading tasks. Used in callback. +/// see [`LoadHandler::load_hyper_parameters`] +#[derive(Debug, Clone)] +pub struct PartialHyperparameters { + pub n_vocab: usize, +} + +/// use this to load params for llama model inside [`LoadHandler::load_hyper_parameters`] +pub fn load_llama_hparams(reader: &mut R) -> Result<(LlamaHyperparameters, PartialHyperparameters), LoadError> { + // NOTE: Field order matters! Data is laid out in the file exactly in this order. + let hparams = LlamaHyperparameters { + n_vocab: read_i32(reader)?.try_into()?, + n_embd: read_i32(reader)?.try_into()?, + n_mult: read_i32(reader)?.try_into()?, + n_head: read_i32(reader)?.try_into()?, + n_layer: read_i32(reader)?.try_into()?, + n_rot: read_i32(reader)?.try_into()?, + tensor_element_type: decode_element_type_res(read_i32(reader)?)?, + }; + let partial = PartialHyperparameters { n_vocab: hparams.n_vocab }; + Ok((hparams, partial)) +} + #[allow(unused_variables)] -pub trait LoadHandler { +pub trait LoadHandler { fn got_container_type(&mut self, model_type: ContainerType) -> ControlFlow { ControlFlow::Continue(()) } - fn got_hyper_parameters(&mut self, hparams: LlamaHyperparameters) -> ControlFlow { - ControlFlow::Continue(()) - } - fn got_vocab_token(&mut self, i: usize, token: Vec, score: f32) -> ControlFlow { ControlFlow::Continue(()) } + fn load_hyper_parameters(&mut self, reader: &mut R) -> ControlFlow; + /// # Returns /// /// `None` to skip copying @@ -117,6 +138,12 @@ pub trait LoadHandler { } } +#[test] +fn can_be_vtable() { + use std::mem::MaybeUninit; + let _a: MaybeUninit>> = MaybeUninit::uninit(); +} + fn retchk(model_type: ControlFlow) -> Result> { match model_type { ControlFlow::Continue(x) => Ok(x), @@ -124,9 +151,9 @@ fn retchk(model_type: ControlFlow) -> Result> { } } -pub fn load_model_from_reader( - mut reader: impl BufRead + Seek, - handler: &mut impl LoadHandler, +pub fn load_model_from_reader( + mut reader: R, + handler: &mut impl LoadHandler, ) -> Result<(), LoadError> { // Verify magic let container_type: ContainerType = match read_u32(&mut reader)? { @@ -149,20 +176,8 @@ pub fn load_model_from_reader( } // Load hyper params - // - // NOTE: Field order matters! Data is laid out in the file exactly - // in this order. 
- let hparams = LlamaHyperparameters { - n_vocab: read_i32(&mut reader)?.try_into()?, - n_embd: read_i32(&mut reader)?.try_into()?, - n_mult: read_i32(&mut reader)?.try_into()?, - n_head: read_i32(&mut reader)?.try_into()?, - n_layer: read_i32(&mut reader)?.try_into()?, - n_rot: read_i32(&mut reader)?.try_into()?, - tensor_element_type: decode_element_type_res(read_i32(&mut reader)?)?, - }; + let hparams = retchk(handler.load_hyper_parameters(&mut reader))?; let n_vocab = hparams.n_vocab; - retchk(handler.got_hyper_parameters(hparams))?; // Load vocabulary for i in 0..n_vocab { @@ -196,9 +211,9 @@ fn decode_element_type_res(ftype: i32) -> Result> { } } -fn load_weights_ggjt( - reader: &mut (impl BufRead + Seek), - handler: &mut impl LoadHandler, +fn load_weights_ggjt( + reader: &mut R, + handler: &mut impl LoadHandler, ) -> Result<(), LoadError> { while has_data_left(reader)? { // load tensor header From 5dfc55d4a89d3e38fe1131d4d334d60a1553951b Mon Sep 17 00:00:00 2001 From: Locria Cyber <74560659+iacore@users.noreply.github.com> Date: Fri, 14 Apr 2023 11:12:53 +0000 Subject: [PATCH 18/42] feat/loader Add hook to support multi-part model loading --- llama-loader/src/lib.rs | 63 ++++++++++++++++++++++++++++++----------- 1 file changed, 46 insertions(+), 17 deletions(-) diff --git a/llama-loader/src/lib.rs b/llama-loader/src/lib.rs index 86a7a00d..53a3e0ab 100644 --- a/llama-loader/src/lib.rs +++ b/llama-loader/src/lib.rs @@ -102,7 +102,9 @@ pub struct PartialHyperparameters { } /// use this to load params for llama model inside [`LoadHandler::load_hyper_parameters`] -pub fn load_llama_hparams(reader: &mut R) -> Result<(LlamaHyperparameters, PartialHyperparameters), LoadError> { +pub fn load_llama_hparams( + reader: &mut R, +) -> Result<(LlamaHyperparameters, PartialHyperparameters), LoadError> { // NOTE: Field order matters! Data is laid out in the file exactly in this order. let hparams = LlamaHyperparameters { n_vocab: read_i32(reader)?.try_into()?, @@ -113,7 +115,9 @@ pub fn load_llama_hparams(reader: &mut R) -> Result<(Llama n_rot: read_i32(reader)?.try_into()?, tensor_element_type: decode_element_type_res(read_i32(reader)?)?, }; - let partial = PartialHyperparameters { n_vocab: hparams.n_vocab }; + let partial = PartialHyperparameters { + n_vocab: hparams.n_vocab, + }; Ok((hparams, partial)) } @@ -129,11 +133,19 @@ pub trait LoadHandler { fn load_hyper_parameters(&mut self, reader: &mut R) -> ControlFlow; + /// multi-file loading is not supported + /// To handle that yourself, return [`ControlFlow::Break(_)`] here + fn load_multipart(&mut self, reader: &mut R) -> ControlFlow { + ControlFlow::Continue(()) + } + + /// callback to get tensor buffer to populate + /// /// # Returns /// /// `None` to skip copying /// `Some(buf)` to provide a buffer for copying weights into - fn get_tensor_buffer(&mut self, info: TensorInfo) -> ControlFlow> { + fn tensor_buffer(&mut self, info: TensorInfo) -> ControlFlow> { ControlFlow::Continue(None) } } @@ -152,11 +164,11 @@ fn retchk(model_type: ControlFlow) -> Result> { } pub fn load_model_from_reader( - mut reader: R, + reader: &mut R, handler: &mut impl LoadHandler, ) -> Result<(), LoadError> { // Verify magic - let container_type: ContainerType = match read_u32(&mut reader)? { + let container_type: ContainerType = match read_u32(reader)? 
{ ggml::FILE_MAGIC_GGMF => ContainerType::GGMF, ggml::FILE_MAGIC_GGJT => ContainerType::GGJT, ggml::FILE_MAGIC_UNVERSIONED => ContainerType::GGML, @@ -167,7 +179,7 @@ pub fn load_model_from_reader( // Load format version match container_type { ContainerType::GGMF | ContainerType::GGJT => { - let _version: u32 = match read_u32(&mut reader)? { + let _version: u32 = match read_u32(reader)? { ggml::FORMAT_VERSION => ggml::FORMAT_VERSION, version => return Err(LoadError::InvalidFormatVersion(version)), }; @@ -176,15 +188,15 @@ pub fn load_model_from_reader( } // Load hyper params - let hparams = retchk(handler.load_hyper_parameters(&mut reader))?; + let hparams = retchk(handler.load_hyper_parameters(reader))?; let n_vocab = hparams.n_vocab; // Load vocabulary for i in 0..n_vocab { - let len = read_u32(&mut reader)?.try_into()?; - let token = read_bytes_with_len(&mut reader, len)?; + let len = read_u32(reader)?.try_into()?; + let token = read_bytes_with_len(reader, len)?; let token_score = match container_type { - ContainerType::GGMF | ContainerType::GGJT => read_f32(&mut reader)?, + ContainerType::GGMF | ContainerType::GGJT => read_f32(reader)?, ContainerType::GGML => { // Legacy model, set empty score 0. @@ -196,11 +208,10 @@ pub fn load_model_from_reader( // Load tensor data match container_type { ContainerType::GGMF | ContainerType::GGML => { - let _file_offset = reader.stream_position()?; - drop(reader); - todo!() + retchk(handler.load_multipart(reader))?; + load_weights(reader, handler, false) } - ContainerType::GGJT => load_weights_ggjt(&mut reader, handler), + ContainerType::GGJT => load_weights_ggjt(reader, handler), } } @@ -214,6 +225,18 @@ fn decode_element_type_res(ftype: i32) -> Result> { fn load_weights_ggjt( reader: &mut R, handler: &mut impl LoadHandler, +) -> Result<(), LoadError> { + load_weights(reader, handler, true) +} + +/// # Params +/// +/// `align` +/// align to 4 bytes before reading tensor weights +fn load_weights( + reader: &mut R, + handler: &mut impl LoadHandler, + align: bool, ) -> Result<(), LoadError> { while has_data_left(reader)? { // load tensor header @@ -249,7 +272,11 @@ fn load_weights_ggjt( // load tensor weights let offset_curr = reader.stream_position()?; - let offset_aligned: u64 = (offset_curr + 31) & !31; + let offset_aligned: u64 = if align { + (offset_curr + 31) & !31 + } else { + offset_curr + }; let tensor_info = TensorInfo { name, @@ -261,8 +288,10 @@ fn load_weights_ggjt( }; let type_size = ggml::type_size(ftype); - if let Some(buf) = retchk(handler.get_tensor_buffer(tensor_info))? { - reader.seek(SeekFrom::Start(offset_aligned))?; + if let Some(buf) = retchk(handler.tensor_buffer(tensor_info))? 
{ + if align { + reader.seek(SeekFrom::Start(offset_aligned))?; + } let buf_len = buf.len(); if !(buf_len == type_size * n_elements) { return Err(LoadError::InvariantBroken(format!( From 48efd74114f7760d507e636b02ad16d2a6bb9ca2 Mon Sep 17 00:00:00 2001 From: Locria Cyber <74560659+iacore@users.noreply.github.com> Date: Fri, 14 Apr 2023 11:22:19 +0000 Subject: [PATCH 19/42] rename llama-loader to ggml-loader --- Cargo.lock | 18 +++--- Cargo.toml | 2 +- {llama-loader => ggml-loader}/Cargo.toml | 2 +- {llama-loader => ggml-loader}/src/lib.rs | 68 +---------------------- {llama-loader => ggml-loader}/src/util.rs | 31 ++++++++++- llama-rs/Cargo.toml | 2 +- llama-rs/src/convert.rs | 2 +- llama-rs/src/lib.rs | 1 + llama-rs/src/loader.rs | 4 +- llama-rs/src/loader2.rs | 36 ++++++++++++ llama-rs/src/model.rs | 2 +- 11 files changed, 86 insertions(+), 82 deletions(-) rename {llama-loader => ggml-loader}/Cargo.toml (90%) rename {llama-loader => ggml-loader}/src/lib.rs (78%) rename {llama-loader => ggml-loader}/src/util.rs (54%) create mode 100644 llama-rs/src/loader2.rs diff --git a/Cargo.lock b/Cargo.lock index bb0ba9ba..6bb4e48e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -386,6 +386,14 @@ dependencies = [ "ggml-sys", ] +[[package]] +name = "ggml-loader" +version = "0.1.0" +dependencies = [ + "ggml", + "thiserror", +] + [[package]] name = "ggml-sys" version = "0.1.0" @@ -534,21 +542,13 @@ dependencies = [ "zstd", ] -[[package]] -name = "llama-loader" -version = "0.1.0" -dependencies = [ - "ggml", - "thiserror", -] - [[package]] name = "llama-rs" version = "0.1.0" dependencies = [ "bytemuck", "ggml", - "llama-loader", + "ggml-loader", "memmap2", "partial_sort", "protobuf", diff --git a/Cargo.toml b/Cargo.toml index 4c383de8..f579b1c6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,7 @@ members = [ "ggml-sys", "ggml", - "llama-loader", + "ggml-loader", "llama-rs", "llama-cli", "generate-ggml-bindings" diff --git a/llama-loader/Cargo.toml b/ggml-loader/Cargo.toml similarity index 90% rename from llama-loader/Cargo.toml rename to ggml-loader/Cargo.toml index cfc8d48b..ab711363 100644 --- a/llama-loader/Cargo.toml +++ b/ggml-loader/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "llama-loader" +name = "ggml-loader" version = "0.1.0" edition = "2021" diff --git a/llama-loader/src/lib.rs b/ggml-loader/src/lib.rs similarity index 78% rename from llama-loader/src/lib.rs rename to ggml-loader/src/lib.rs index 53a3e0ab..1d4c8f4c 100644 --- a/llama-loader/src/lib.rs +++ b/ggml-loader/src/lib.rs @@ -5,10 +5,7 @@ pub mod util; -use std::{ - io::{BufRead, Seek, SeekFrom}, - ops::ControlFlow, -}; +use std::ops::ControlFlow; use util::*; pub type ElementType = ggml::Type; @@ -25,38 +22,6 @@ pub enum ContainerType { GGJT, } -pub fn decode_element_type(ftype: i32) -> Option { - match ftype { - 0 => Some(ggml::Type::F32), - 1 => Some(ggml::Type::F16), - 2 => Some(ggml::Type::Q4_0), - 3 => Some(ggml::Type::Q4_1), - _ => None, - } -} - -pub fn encode_element_type(element_type: ElementType) -> Option { - match element_type { - ggml::Type::F32 => Some(0), - ggml::Type::F16 => Some(1), - ggml::Type::Q4_0 => Some(2), - ggml::Type::Q4_1 => Some(3), - _ => None, - } -} - -/// The hyperparameters of the model. 
-#[derive(Debug, Clone)] -pub struct LlamaHyperparameters { - pub n_vocab: usize, - pub n_embd: usize, - pub n_mult: usize, - pub n_head: usize, - pub n_layer: usize, - pub n_rot: usize, - pub tensor_element_type: ElementType, -} - #[derive(Debug, thiserror::Error)] pub enum LoadError { #[error("invalid file magic number: {0}")] @@ -101,26 +66,6 @@ pub struct PartialHyperparameters { pub n_vocab: usize, } -/// use this to load params for llama model inside [`LoadHandler::load_hyper_parameters`] -pub fn load_llama_hparams( - reader: &mut R, -) -> Result<(LlamaHyperparameters, PartialHyperparameters), LoadError> { - // NOTE: Field order matters! Data is laid out in the file exactly in this order. - let hparams = LlamaHyperparameters { - n_vocab: read_i32(reader)?.try_into()?, - n_embd: read_i32(reader)?.try_into()?, - n_mult: read_i32(reader)?.try_into()?, - n_head: read_i32(reader)?.try_into()?, - n_layer: read_i32(reader)?.try_into()?, - n_rot: read_i32(reader)?.try_into()?, - tensor_element_type: decode_element_type_res(read_i32(reader)?)?, - }; - let partial = PartialHyperparameters { - n_vocab: hparams.n_vocab, - }; - Ok((hparams, partial)) -} - #[allow(unused_variables)] pub trait LoadHandler { fn got_container_type(&mut self, model_type: ContainerType) -> ControlFlow { @@ -215,14 +160,7 @@ pub fn load_model_from_reader( } } -fn decode_element_type_res(ftype: i32) -> Result> { - match decode_element_type(ftype) { - Some(x) => Ok(x), - None => Err(LoadError::UnsupportedElementtype(ftype)), - } -} - -fn load_weights_ggjt( +pub fn load_weights_ggjt( reader: &mut R, handler: &mut impl LoadHandler, ) -> Result<(), LoadError> { @@ -233,7 +171,7 @@ fn load_weights_ggjt( /// /// `align` /// align to 4 bytes before reading tensor weights -fn load_weights( +pub fn load_weights( reader: &mut R, handler: &mut impl LoadHandler, align: bool, diff --git a/llama-loader/src/util.rs b/ggml-loader/src/util.rs similarity index 54% rename from llama-loader/src/util.rs rename to ggml-loader/src/util.rs index 06e5312f..b0864dc5 100644 --- a/llama-loader/src/util.rs +++ b/ggml-loader/src/util.rs @@ -1,4 +1,6 @@ -use std::io::BufRead; +pub use std::io::{BufRead, Seek, SeekFrom}; + +use crate::{ElementType, LoadError}; pub fn read_bytes(reader: &mut impl BufRead) -> Result<[u8; N], std::io::Error> { let mut bytes = [0u8; N]; @@ -31,3 +33,30 @@ pub fn read_bytes_with_len( pub fn has_data_left(reader: &mut impl BufRead) -> Result { reader.fill_buf().map(|b| !b.is_empty()) } + +pub fn decode_element_type(ftype: i32) -> Option { + match ftype { + 0 => Some(ggml::Type::F32), + 1 => Some(ggml::Type::F16), + 2 => Some(ggml::Type::Q4_0), + 3 => Some(ggml::Type::Q4_1), + _ => None, + } +} + +pub fn encode_element_type(element_type: ElementType) -> Option { + match element_type { + ggml::Type::F32 => Some(0), + ggml::Type::F16 => Some(1), + ggml::Type::Q4_0 => Some(2), + ggml::Type::Q4_1 => Some(3), + _ => None, + } +} + +pub fn decode_element_type_res(ftype: i32) -> Result> { + match decode_element_type(ftype) { + Some(x) => Ok(x), + None => Err(LoadError::UnsupportedElementtype(ftype)), + } +} diff --git a/llama-rs/Cargo.toml b/llama-rs/Cargo.toml index a064520e..4e854a43 100644 --- a/llama-rs/Cargo.toml +++ b/llama-rs/Cargo.toml @@ -8,7 +8,7 @@ rust-version = "1.65" [dependencies] ggml = { path = "../ggml" } -llama-loader = { path = "../llama-loader" } +ggml-loader = { path = "../ggml-loader" } rand = { workspace = true } diff --git a/llama-rs/src/convert.rs b/llama-rs/src/convert.rs index 9be1f5ed..67557b8f 100644 
--- a/llama-rs/src/convert.rs +++ b/llama-rs/src/convert.rs @@ -17,7 +17,7 @@ use std::{ }; use crate::{util, Hyperparameters, Vocabulary}; -use llama_loader::encode_element_type; +use ggml_loader::util::encode_element_type; /// Converts a `pth` file to a `ggml` file. pub fn convert_pth_to_ggml(model_directory: &Path, element_type: ggml::Type) { diff --git a/llama-rs/src/lib.rs b/llama-rs/src/lib.rs index 941e3159..db273d6f 100644 --- a/llama-rs/src/lib.rs +++ b/llama-rs/src/lib.rs @@ -8,6 +8,7 @@ pub mod convert; mod inference_session; mod loader; +mod loader2; mod model; mod util; mod vocabulary; diff --git a/llama-rs/src/loader.rs b/llama-rs/src/loader.rs index 32991761..2575597f 100644 --- a/llama-rs/src/loader.rs +++ b/llama-rs/src/loader.rs @@ -9,8 +9,8 @@ use crate::{ Mmap, Model, TokenId, Vocabulary, }; use crate::{ElementType, Hyperparameters}; -use llama_loader::util::*; -use llama_loader::{decode_element_type, ContainerType}; +use ggml_loader::util::*; +use ggml_loader::ContainerType; use thiserror::Error; /// Each variant represents a step within the process of loading the model. diff --git a/llama-rs/src/loader2.rs b/llama-rs/src/loader2.rs new file mode 100644 index 00000000..f2a8156f --- /dev/null +++ b/llama-rs/src/loader2.rs @@ -0,0 +1,36 @@ +//! ggml-loader aux + +use ggml_loader::*; +use ggml_loader::util::*; + +/// The hyperparameters of the model. +#[derive(Debug, Clone)] +pub struct LlamaHyperparameters { + pub n_vocab: usize, + pub n_embd: usize, + pub n_mult: usize, + pub n_head: usize, + pub n_layer: usize, + pub n_rot: usize, + pub tensor_element_type: ElementType, +} + +/// use this to load params for llama model inside [`LoadHandler::load_hyper_parameters`] +pub fn load_llama_hparams( + reader: &mut R, +) -> Result<(LlamaHyperparameters, PartialHyperparameters), LoadError> { + // NOTE: Field order matters! Data is laid out in the file exactly in this order. + let hparams = LlamaHyperparameters { + n_vocab: read_i32(reader)?.try_into()?, + n_embd: read_i32(reader)?.try_into()?, + n_mult: read_i32(reader)?.try_into()?, + n_head: read_i32(reader)?.try_into()?, + n_layer: read_i32(reader)?.try_into()?, + n_rot: read_i32(reader)?.try_into()?, + tensor_element_type: decode_element_type_res(read_i32(reader)?)?, + }; + let partial = PartialHyperparameters { + n_vocab: hparams.n_vocab, + }; + Ok((hparams, partial)) +} diff --git a/llama-rs/src/model.rs b/llama-rs/src/model.rs index 95451b9a..c08f1256 100644 --- a/llama-rs/src/model.rs +++ b/llama-rs/src/model.rs @@ -5,7 +5,7 @@ use crate::{ InferenceSessionParameters, LoadError, LoadProgress, Mmap, Vocabulary, }; -use llama_loader::ContainerType; +use ggml_loader::ContainerType; /// The weights for the LLaMA model. All the mutable state is split into a /// separate struct `InferenceSession`. 
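The callback-driven loader that the preceding patches build up (PATCH 14-19/42) is easiest to see from the consumer side. Below is a minimal, illustrative sketch of a LoadHandler that counts tensors without copying any weights, written against ggml-loader as it stands after the rename above. The generic parameters (T for the handler's break/error type, R for the reader) and the exact bounds are assumptions reconstructed from the signatures in these diffs, since the angle brackets are partially lost in this rendering; treat it as a sketch under those assumptions, not the crate's verbatim API.

use std::io::{BufRead, Seek};
use std::ops::ControlFlow;

use ggml_loader::util::read_i32;
use ggml_loader::{LoadHandler, PartialHyperparameters, TensorInfo};

/// Counts tensors in a GGJT/GGMF/GGML file without copying weight data.
struct TensorCounter {
    tensors: usize,
}

impl<R: BufRead + Seek> LoadHandler<String, R> for TensorCounter {
    fn load_hyper_parameters(
        &mut self,
        reader: &mut R,
    ) -> ControlFlow<String, PartialHyperparameters> {
        // The loader only needs `n_vocab` back from this callback; read it and skip the
        // remaining six i32 hyperparameter fields of a LLaMA header (n_embd, n_mult,
        // n_head, n_layer, n_rot, element type).
        let n_vocab = match read_i32(reader) {
            Ok(v) => v as usize,
            Err(e) => return ControlFlow::Break(e.to_string()),
        };
        for _ in 0..6 {
            if let Err(e) = read_i32(reader) {
                return ControlFlow::Break(e.to_string());
            }
        }
        ControlFlow::Continue(PartialHyperparameters { n_vocab })
    }

    fn tensor_buffer(&mut self, _info: TensorInfo) -> ControlFlow<String, Option<&mut [u8]>> {
        self.tensors += 1;
        // Returning `None` asks the loader to seek past the weights instead of copying them.
        ControlFlow::Continue(None)
    }
}

fn count_tensors(path: &std::path::Path) -> Result<usize, String> {
    let file = std::fs::File::open(path).map_err(|e| e.to_string())?;
    let mut reader = std::io::BufReader::new(file);
    let mut handler = TensorCounter { tensors: 0 };
    ggml_loader::load_model_from_reader(&mut reader, &mut handler)
        .map_err(|e| format!("{e:?}"))?;
    Ok(handler.tensors)
}

Any ControlFlow::Break value returned from the handler surfaces as LoadError::UserInterrupted, which is what makes the same trait usable both for full model construction (as loader2.rs does later in this series) and for cheap metadata passes like the one sketched here.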
From d65996d13e0d23af62b08b7d949d165f254c2100 Mon Sep 17 00:00:00 2001 From: jon-chuang Date: Wed, 12 Apr 2023 14:59:46 +0800 Subject: [PATCH 20/42] fix --- ggml/src/lib.rs | 31 ++++++++++++++++++++----------- llama-cli/Cargo.toml | 5 ++++- llama-cli/src/cli_args.rs | 7 ++++++- llama-rs/src/loader.rs | 4 ++-- 4 files changed, 32 insertions(+), 15 deletions(-) diff --git a/ggml/src/lib.rs b/ggml/src/lib.rs index 87aac15a..93956456 100644 --- a/ggml/src/lib.rs +++ b/ggml/src/lib.rs @@ -526,6 +526,14 @@ impl Tensor { } } + fn with_alive_ctx_mut(&self, mut f: impl FnMut() -> U) -> U { + if let Some(_ctx) = self.ctx.upgrade() { + f() + } else { + panic!("Using a tensor after the context was dropped") + } + } + /// Number of bytes used by this tensor. pub fn nbytes(&self) -> usize { self.with_alive_ctx(|| { @@ -546,17 +554,18 @@ impl Tensor { }) } - // /// Set the tensor's data pointer (useful for mmap-ed data) - // /// - // /// # Safety - // /// - // /// The memory region from `data_ptr` to `data_ptr.offset(tensor.nbytes())` will be read from. - // pub unsafe fn set_data(&self, data_ptr: *mut c_void) { - // self.with_alive_ctx(|| { - // // SAFETY: The with_alive_call guarantees the context is alive - // unsafe { *self.ptr.as_ptr() }.data = data_ptr; - // }) - // } + /// Set the tensor's data pointer (useful for mmap-ed data) + /// + /// # Safety + /// + /// The memory region from `data_ptr` to `data_ptr.offset(tensor.nbytes())` will be read from. + pub unsafe fn set_data(&mut self, data_ptr: *mut c_void) { + let tensor = self.ptr.as_mut(); + self.with_alive_ctx_mut(|| { + // SAFETY: The with_alive_call guarantees the context is alive + tensor.data = data_ptr; + }) + } /// Number of elements in this tensor. pub fn nelements(&self) -> usize { diff --git a/llama-cli/Cargo.toml b/llama-cli/Cargo.toml index 2eff43b7..72727e0f 100644 --- a/llama-cli/Cargo.toml +++ b/llama-cli/Cargo.toml @@ -18,4 +18,7 @@ num_cpus = "1.15.0" once_cell = "1.17.1" rustyline = "11.0.0" spinners = "4.1.0" -zstd = { version = "0.12", default-features = false } \ No newline at end of file +zstd = { version = "0.12", default-features = false } + +[features] +mmap = ["llama-rs/mmap"] \ No newline at end of file diff --git a/llama-cli/src/cli_args.rs b/llama-cli/src/cli_args.rs index ff9556c5..823394f2 100644 --- a/llama-cli/src/cli_args.rs +++ b/llama-cli/src/cli_args.rs @@ -232,6 +232,7 @@ impl Generate { } }), play_back_previous_tokens: session_loaded, + ..Default::default() } } } @@ -261,6 +262,7 @@ pub struct ModelLoad { } impl ModelLoad { pub fn load(&self) -> llama_rs::Model { + let now = std::time::Instant::now(); let model = llama_rs::Model::load(&self.model_path, self.num_ctx_tokens, |progress| { use llama_rs::LoadProgress; match progress { @@ -310,7 +312,10 @@ impl ModelLoad { }) .expect("Could not load model"); - log::info!("Model fully loaded!"); + log::info!( + "Model fully loaded! 
Elapsed: {}ms", + now.elapsed().as_millis() + ); model } diff --git a/llama-rs/src/loader.rs b/llama-rs/src/loader.rs index 2575597f..c500a323 100644 --- a/llama-rs/src/loader.rs +++ b/llama-rs/src/loader.rs @@ -647,7 +647,7 @@ fn load_weights_ggjt( fn load_tensor_ggjt( reader: &mut (impl BufRead + Seek), mmap_base: *const u8, - tensor: &ggml::Tensor, + tensor: &mut ggml::Tensor, ) -> Result<(), LoadError> { let offset_curr = reader.stream_position()?; let offset_aligned: u64 = (offset_curr + 31) & !31; @@ -655,7 +655,7 @@ fn load_tensor_ggjt( let ptr = mmap_base.offset(offset_aligned as isize); tensor.set_data(ptr as *mut std::ffi::c_void); } - reader.seek(SeekFrom::Start(offset_aligned + tensor.nbytes() as u8))?; + reader.seek(SeekFrom::Start(offset_aligned + tensor.nbytes() as u64))?; Ok(()) } From 267d8ae99cc8f6008bb040afe25ca3bf781d47a5 Mon Sep 17 00:00:00 2001 From: jon-chuang Date: Thu, 13 Apr 2023 00:18:37 +0800 Subject: [PATCH 21/42] no_alloc --- ggml/src/lib.rs | 4 ++-- llama-rs/src/inference_session.rs | 8 +++---- llama-rs/src/loader.rs | 36 +++++++++++++++++-------------- llama-rs/src/model.rs | 2 +- llama-rs/src/util.rs | 2 +- 5 files changed, 28 insertions(+), 24 deletions(-) diff --git a/ggml/src/lib.rs b/ggml/src/lib.rs index 93956456..11d4246b 100644 --- a/ggml/src/lib.rs +++ b/ggml/src/lib.rs @@ -86,14 +86,14 @@ pub struct Context { } impl Context { /// Creates a new [Context] with the specified `mem_size` as a working area. - pub fn init(mem_size: usize) -> Self { + pub fn init(mem_size: usize, alloc: bool) -> Self { let raw = unsafe { ggml_sys::ggml_init(ggml_sys::ggml_init_params { mem_size, // Null here means we want ggml to own this memory. We don't // support passing an owned buffer from the Rust side. mem_buffer: std::ptr::null_mut(), - no_alloc: false, + no_alloc: !alloc, }) }; Self { diff --git a/llama-rs/src/inference_session.rs b/llama-rs/src/inference_session.rs index 3af27812..90179df2 100644 --- a/llama-rs/src/inference_session.rs +++ b/llama-rs/src/inference_session.rs @@ -365,7 +365,7 @@ impl InferenceSession { ctx_size }; - let session_ctx = ggml::Context::init(ctx_size); + let session_ctx = ggml::Context::init(ctx_size, true); // Initialize key + value memory tensors let n_mem = n_layer * n_ctx; @@ -397,7 +397,7 @@ impl InferenceSession { } impl Clone for InferenceSession { fn clone(&self) -> Self { - let context = ggml::Context::init(self.memory_size); + let context = ggml::Context::init(self.memory_size, true); let memory_k = context.new_tensor_1d(self.memory_k.get_type(), self.memory_k.nelements()); let memory_v = context.new_tensor_1d(self.memory_v.get_type(), self.memory_v.nelements()); @@ -493,7 +493,7 @@ pub struct InferenceSnapshot { pub memory_v: Vec, } -#[derive(Clone, Copy, Debug, PartialEq, serde::Serialize, serde::Deserialize)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)] /// Parameters for an inference session. pub struct InferenceSessionParameters { /// The number of tokens to consider for the repetition penalty. @@ -550,7 +550,7 @@ impl Display for InferenceStats { } /// Allowed types for the model memory K/V tensors. -#[derive(Clone, Copy, Debug, PartialEq, serde::Serialize, serde::Deserialize)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)] pub enum ModelKVMemoryType { /// 16-bit float. 
Float16, diff --git a/llama-rs/src/loader.rs b/llama-rs/src/loader.rs index c500a323..131e851a 100644 --- a/llama-rs/src/loader.rs +++ b/llama-rs/src/loader.rs @@ -247,30 +247,34 @@ pub(crate) fn load( let n_layer = hparams.n_layer; let n_vocab = hparams.n_vocab; + let alloc = !(cfg!(feature = "mmap") && model_type == ContainerType::GGJT); + let ctx_size = { // Use 64-bit math to prevent overflow. - let mut ctx_size: usize = 0; - - ctx_size += mulf!(n_embd, n_vocab, ggml::type_sizef(wtype)); // tok_embeddings + let mut ctx_size: usize = (5 + 10 * n_layer) * 256; // object overhead - ctx_size += mulf!(n_embd, ggml::type_sizef(ggml::Type::F32)); // norm + if alloc { + let mut model_size: usize = 0; - ctx_size += mulf!(n_embd, n_vocab, ggml::type_sizef(wtype)); // output + ctx_size += mulf!(n_embd, n_vocab, ggml::type_sizef(wtype)); // tok_embeddings + ctx_size += mulf!(n_embd, ggml::type_sizef(ggml::Type::F32)); // norm + ctx_size += mulf!(n_embd, n_vocab, ggml::type_sizef(wtype)); // output - ctx_size += mulf!(n_layer, n_embd, ggml::type_sizef(ggml::Type::F32)); // attention_norm + model_size += mulf!(n_layer, n_embd, ggml::type_sizef(ggml::Type::F32)); // attention_norm - ctx_size += mulf!(n_layer, n_embd, n_embd, ggml::type_sizef(wtype)); // wq - ctx_size += mulf!(n_layer, n_embd, n_embd, ggml::type_sizef(wtype)); // wk - ctx_size += mulf!(n_layer, n_embd, n_embd, ggml::type_sizef(wtype)); // wv - ctx_size += mulf!(n_layer, n_embd, n_embd, ggml::type_sizef(wtype)); // wo + model_size += mulf!(n_layer, n_embd, n_embd, ggml::type_sizef(wtype)); // wq + model_size += mulf!(n_layer, n_embd, n_embd, ggml::type_sizef(wtype)); // wk + model_size += mulf!(n_layer, n_embd, n_embd, ggml::type_sizef(wtype)); // wv + model_size += mulf!(n_layer, n_embd, n_embd, ggml::type_sizef(wtype)); // wo - ctx_size += mulf!(n_layer, n_embd, ggml::type_sizef(ggml::Type::F32)); // ffn_norm + model_size += mulf!(n_layer, n_embd, ggml::type_sizef(ggml::Type::F32)); // ffn_norm - ctx_size += mulf!(n_layer, n_ff, n_embd, ggml::type_sizef(wtype)); // w1 - ctx_size += mulf!(n_layer, n_ff, n_embd, ggml::type_sizef(wtype)); // w2 - ctx_size += mulf!(n_layer, n_ff, n_embd, ggml::type_sizef(wtype)); // w3 + model_size += mulf!(n_layer, n_ff, n_embd, ggml::type_sizef(wtype)); // w1 + model_size += mulf!(n_layer, n_ff, n_embd, ggml::type_sizef(wtype)); // w2 + model_size += mulf!(n_layer, n_ff, n_embd, ggml::type_sizef(wtype)); // w3 - ctx_size += (5 + 10 * n_layer) * 256; // object overhead + ctx_size += model_size; + } load_progress_callback(LoadProgress::ContextSize { bytes: ctx_size }); @@ -278,7 +282,7 @@ pub(crate) fn load( }; // Initialize the context - let context = ggml::Context::init(ctx_size); + let context = ggml::Context::init(ctx_size, alloc); let mut model = Model::new(context, hparams, vocabulary, n_ff, wtype, model_type); match model_type { diff --git a/llama-rs/src/model.rs b/llama-rs/src/model.rs index c08f1256..b4cfa017 100644 --- a/llama-rs/src/model.rs +++ b/llama-rs/src/model.rs @@ -183,7 +183,7 @@ impl Model { // add 10% to account for ggml object overhead buf_size = (1.1f64 * session.mem_per_token as f64 * n as f64) as usize; }; - let ctx0 = ggml::Context::init(buf_size); + let ctx0 = ggml::Context::init(buf_size, true); let mut gf = ggml::ComputationGraph::new(n_threads); diff --git a/llama-rs/src/util.rs b/llama-rs/src/util.rs index 3eb8f06d..2f4b702a 100644 --- a/llama-rs/src/util.rs +++ b/llama-rs/src/util.rs @@ -21,7 +21,7 @@ pub(crate) use mulf; /// /// Tokens are *not* valid UTF-8 by 
themselves. However, the LLM will produce valid UTF-8 /// from multiple tokens. This helps alleviate that issue. -#[derive(Clone, PartialEq, Default)] +#[derive(Clone, PartialEq, Eq, Default)] pub struct TokenUtf8Buffer(Vec); impl TokenUtf8Buffer { /// Create a new buffer. From 81a69796b560c0a440a901d0f4d80b2977899ddc Mon Sep 17 00:00:00 2001 From: Philpax Date: Wed, 19 Apr 2023 23:06:09 +0200 Subject: [PATCH 22/42] chore: fix clippy --- llama-cli/src/cli_args.rs | 1 - llama-rs/src/loader2.rs | 3 ++- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/llama-cli/src/cli_args.rs b/llama-cli/src/cli_args.rs index 823394f2..608b8647 100644 --- a/llama-cli/src/cli_args.rs +++ b/llama-cli/src/cli_args.rs @@ -232,7 +232,6 @@ impl Generate { } }), play_back_previous_tokens: session_loaded, - ..Default::default() } } } diff --git a/llama-rs/src/loader2.rs b/llama-rs/src/loader2.rs index f2a8156f..757528fd 100644 --- a/llama-rs/src/loader2.rs +++ b/llama-rs/src/loader2.rs @@ -1,7 +1,7 @@ //! ggml-loader aux -use ggml_loader::*; use ggml_loader::util::*; +use ggml_loader::*; /// The hyperparameters of the model. #[derive(Debug, Clone)] @@ -16,6 +16,7 @@ pub struct LlamaHyperparameters { } /// use this to load params for llama model inside [`LoadHandler::load_hyper_parameters`] +#[allow(dead_code)] pub fn load_llama_hparams( reader: &mut R, ) -> Result<(LlamaHyperparameters, PartialHyperparameters), LoadError> { From 80d189e31d8a6f8340b3e6652ae798bc3bad85b4 Mon Sep 17 00:00:00 2001 From: Philpax Date: Wed, 19 Apr 2023 23:15:38 +0200 Subject: [PATCH 23/42] refactor(util): make find_all_model_files error --- llama-rs/src/lib.rs | 2 +- llama-rs/src/loader.rs | 8 ++++++++ llama-rs/src/util.rs | 25 ++++++++++++++++++++++--- 3 files changed, 31 insertions(+), 4 deletions(-) diff --git a/llama-rs/src/lib.rs b/llama-rs/src/lib.rs index db273d6f..fea3ddea 100644 --- a/llama-rs/src/lib.rs +++ b/llama-rs/src/lib.rs @@ -33,7 +33,7 @@ pub(crate) struct Mmap; /// dummy impl #[cfg(not(feature = "mmap"))] impl Mmap { - pub(crate) unsafe fn map(_: &std::fs::File) -> Result { + pub(crate) unsafe fn map(_: &std::fs::File) -> std::io::Result { Ok(Mmap) } pub(crate) fn as_ptr(&self) -> *const u8 { diff --git a/llama-rs/src/loader.rs b/llama-rs/src/loader.rs index 131e851a..5f96c629 100644 --- a/llama-rs/src/loader.rs +++ b/llama-rs/src/loader.rs @@ -136,6 +136,14 @@ pub enum LoadError { path: PathBuf, }, } +impl From for LoadError { + fn from(value: util::FindAllModelFilesError) -> Self { + match value { + util::FindAllModelFilesError::NoParentPath { path } => LoadError::NoParentPath { path }, + util::FindAllModelFilesError::IO(err) => LoadError::IO(err), + } + } +} pub(crate) fn load( path: impl AsRef, diff --git a/llama-rs/src/util.rs b/llama-rs/src/util.rs index 2f4b702a..3044bdb6 100644 --- a/llama-rs/src/util.rs +++ b/llama-rs/src/util.rs @@ -16,6 +16,7 @@ macro_rules! mulf { } pub(crate) use mulf; +use thiserror::Error; /// Used to buffer incoming tokens until they produce a valid string of UTF-8 text. /// @@ -69,11 +70,29 @@ impl TokenUtf8Buffer { } } -pub(crate) fn find_all_model_files(main_path: &Path) -> Result, LoadError> { +#[derive(Error, Debug)] +/// Errors encountered during the loading process. +pub enum FindAllModelFilesError { + #[error("no parent path for {path:?}")] + /// There is no parent path for a given path. + NoParentPath { + /// The path without a parent. + path: PathBuf, + }, + #[error("non-specific I/O error")] + /// A non-specific IO error. 
+ IO(#[from] std::io::Error), +} + +pub(crate) fn find_all_model_files( + main_path: &Path, +) -> Result, FindAllModelFilesError> { Ok(collect_related_paths( main_path, - std::fs::read_dir(main_path.parent().ok_or_else(|| LoadError::NoParentPath { - path: main_path.to_owned(), + std::fs::read_dir(main_path.parent().ok_or_else(|| { + FindAllModelFilesError::NoParentPath { + path: main_path.to_owned(), + } })?)? .filter_map(Result::ok) .map(|de| de.path()), From 85e1148218f33203e422769403e12e859058e860 Mon Sep 17 00:00:00 2001 From: Philpax Date: Wed, 19 Apr 2023 23:28:52 +0200 Subject: [PATCH 24/42] UnsupportedElementtype -> UnsupportedElementType --- ggml-loader/src/lib.rs | 2 +- ggml-loader/src/util.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml-loader/src/lib.rs b/ggml-loader/src/lib.rs index 1d4c8f4c..b490868a 100644 --- a/ggml-loader/src/lib.rs +++ b/ggml-loader/src/lib.rs @@ -41,7 +41,7 @@ pub enum LoadError { UserInterrupted(T), #[error("unsupported tensor dtype/f16_: {0}")] - UnsupportedElementtype(i32), + UnsupportedElementType(i32), /// sanity check failed #[error("invariant broken: {0}")] diff --git a/ggml-loader/src/util.rs b/ggml-loader/src/util.rs index b0864dc5..342f3bd0 100644 --- a/ggml-loader/src/util.rs +++ b/ggml-loader/src/util.rs @@ -57,6 +57,6 @@ pub fn encode_element_type(element_type: ElementType) -> Option { pub fn decode_element_type_res(ftype: i32) -> Result> { match decode_element_type(ftype) { Some(x) => Ok(x), - None => Err(LoadError::UnsupportedElementtype(ftype)), + None => Err(LoadError::UnsupportedElementType(ftype)), } } From 3f29992c6841b210d8e5a632a589a294c84912e0 Mon Sep 17 00:00:00 2001 From: Philpax Date: Thu, 20 Apr 2023 00:49:26 +0200 Subject: [PATCH 25/42] feat: experimental loader2 wire-up (incomplete) --- llama-rs/src/lib.rs | 3 +- llama-rs/src/loader.rs | 141 +---------------- llama-rs/src/loader2.rs | 283 ++++++++++++++++++++++++++++++++-- llama-rs/src/loader_common.rs | 168 ++++++++++++++++++++ llama-rs/src/model.rs | 14 +- llama-rs/src/util.rs | 2 - llama-rs/src/vocabulary.rs | 2 +- 7 files changed, 453 insertions(+), 160 deletions(-) create mode 100644 llama-rs/src/loader_common.rs diff --git a/llama-rs/src/lib.rs b/llama-rs/src/lib.rs index fea3ddea..dc91f864 100644 --- a/llama-rs/src/lib.rs +++ b/llama-rs/src/lib.rs @@ -9,6 +9,7 @@ pub mod convert; mod inference_session; mod loader; mod loader2; +mod loader_common; mod model; mod util; mod vocabulary; @@ -18,7 +19,7 @@ pub use inference_session::{ InferenceSession, InferenceSessionParameters, InferenceSnapshot, ModelKVMemoryType, SnapshotError, }; -pub use loader::{LoadError, LoadProgress}; +pub use loader_common::{LoadError, LoadProgress, UnexpectedState}; pub use model::{Hyperparameters, Model}; pub use util::TokenUtf8Buffer; pub use vocabulary::{TokenBias, TokenId, Vocabulary}; diff --git a/llama-rs/src/loader.rs b/llama-rs/src/loader.rs index 5f96c629..2b2ad656 100644 --- a/llama-rs/src/loader.rs +++ b/llama-rs/src/loader.rs @@ -1,149 +1,18 @@ +#![allow(dead_code)] + use std::{ collections::HashMap, io::{BufRead, Read, Seek, SeekFrom}, - path::{Path, PathBuf}, + path::Path, }; use crate::{ util::{self, mulf}, - Mmap, Model, TokenId, Vocabulary, + LoadError, LoadProgress, Mmap, Model, TokenId, Vocabulary, }; use crate::{ElementType, Hyperparameters}; use ggml_loader::util::*; use ggml_loader::ContainerType; -use thiserror::Error; - -/// Each variant represents a step within the process of loading the model. 
-/// These can be used to report progress to the user. -#[derive(Clone, PartialEq, Eq, Debug)] -pub enum LoadProgress<'a> { - /// The hyperparameters have been loaded from the model. - HyperparametersLoaded(&'a Hyperparameters), - /// The context has been created. - ContextSize { - /// The size of the context. - bytes: usize, - }, - /// A part of the model is being loaded. - PartLoading { - /// The path to the model part. - file: &'a Path, - /// The current part (0-indexed). - current_part: usize, - /// The number of total parts. - total_parts: usize, - }, - /// A tensor from the current part has been loaded. - PartTensorLoaded { - /// The path to the model part. - file: &'a Path, - /// The current tensor (0-indexed). - current_tensor: usize, - /// The number of total tensors. - tensor_count: usize, - }, - /// A model part has finished fully loading. - PartLoaded { - /// The path to the model part. - file: &'a Path, - /// The number of bytes in the part. - byte_size: usize, - /// The number of tensors in the part. - tensor_count: usize, - }, -} - -#[derive(Error, Debug)] -/// Errors encountered during the loading process. -pub enum LoadError { - #[error("could not open file {path:?}")] - /// A file failed to open. - OpenFileFailed { - /// The original error. - source: std::io::Error, - /// The path that failed. - path: PathBuf, - }, - #[error("no parent path for {path:?}")] - /// There is no parent path for a given path. - NoParentPath { - /// The path without a parent. - path: PathBuf, - }, - #[error("unable to read exactly {bytes} bytes")] - /// Reading exactly `bytes` from a file failed. - ReadExactFailed { - /// The original error. - source: std::io::Error, - /// The number of bytes that were attempted to be read. - bytes: usize, - }, - #[error("non-specific I/O error")] - /// A non-specific IO error. - IO(#[from] std::io::Error), - #[error("could not convert bytes to a UTF-8 string")] - /// One of the strings encountered was not valid UTF-8. - InvalidUtf8(#[from] std::string::FromUtf8Error), - #[error("invalid integer conversion")] - /// One of the integers encountered could not be converted to a more appropriate type. - InvalidIntegerConversion(#[from] std::num::TryFromIntError), - #[error("unsupported f16_: {0}")] - /// One of the integers encountered could not be converted to a more appropriate type. - UnsupportedElementType(i32), - #[error("invalid magic number for {path:?}")] - /// An invalid magic number was encountered during the loading process. - InvalidMagic { - /// The path that failed. - path: PathBuf, - }, - #[error("invalid file format version {value}")] - /// The version of the format is not supported by this version of `llama-rs`. - InvalidFormatVersion { - /// The version that was encountered. - value: u32, - }, - #[error("invalid value {ftype} for `f16` in hyperparameters")] - /// The `f16` hyperparameter had an invalid value. - HyperparametersF16Invalid { - /// The format type that was encountered. - ftype: u32, - }, - #[error("unknown tensor `{tensor_name}` in {path:?}")] - /// The tensor `tensor_name` was encountered during the loading of `path`, but was not seen during - /// the model prelude. - UnknownTensor { - /// The name of the tensor. - tensor_name: String, - /// The path that failed. - path: PathBuf, - }, - #[error("the tensor `{tensor_name}` has the wrong size in {path:?}")] - /// The tensor `tensor_name` did not match its expected size. - TensorWrongSize { - /// The name of the tensor. - tensor_name: String, - /// The path that failed. 
- path: PathBuf, - }, - /// The tensor `tensor_name` did not have the expected format type. - #[error("invalid ftype {ftype} for tensor `{tensor_name}` in {path:?}")] - InvalidFtype { - /// The name of the tensor. - tensor_name: String, - /// The format type that was encountered. - ftype: i32, - /// The path that failed. - path: PathBuf, - }, -} -impl From for LoadError { - fn from(value: util::FindAllModelFilesError) -> Self { - match value { - util::FindAllModelFilesError::NoParentPath { path } => LoadError::NoParentPath { path }, - util::FindAllModelFilesError::IO(err) => LoadError::IO(err), - } - } -} pub(crate) fn load( path: impl AsRef, @@ -178,7 +47,7 @@ pub(crate) fn load( ContainerType::GGMF | ContainerType::GGJT => { let _version: u32 = match read_u32(&mut reader)? { ggml::FORMAT_VERSION => ggml::FORMAT_VERSION, - version => return Err(LoadError::InvalidFormatVersion { value: version }), + version => return Err(LoadError::InvalidFormatVersion { version }), }; } ContainerType::GGML => {} diff --git a/llama-rs/src/loader2.rs b/llama-rs/src/loader2.rs index 757528fd..6440589c 100644 --- a/llama-rs/src/loader2.rs +++ b/llama-rs/src/loader2.rs @@ -1,34 +1,285 @@ -//! ggml-loader aux +//! This is an experimental, *incomplete* implementation of a loader based on `ggml_loader`. +//! +//! At the time of writing, it does not successfully load any models. +//! +//! GGML/GGMF fails with an invariant broken error, and GGJT fails with an unexpected state error. +//! +//! It also does not support mmap, but it shouldn't be too hard to add: mmap as is done in `loader`, then populate +//! the tensor from the [TensorInfo]. use ggml_loader::util::*; use ggml_loader::*; -/// The hyperparameters of the model. -#[derive(Debug, Clone)] -pub struct LlamaHyperparameters { - pub n_vocab: usize, - pub n_embd: usize, - pub n_mult: usize, - pub n_head: usize, - pub n_layer: usize, - pub n_rot: usize, - pub tensor_element_type: ElementType, +use std::{ + fs::File, + io::{BufRead, BufReader, Seek}, + ops::ControlFlow, + path::{Path, PathBuf}, +}; + +use crate::{ + util::mulf, Hyperparameters, LoadError, LoadProgress, Model, TokenId, UnexpectedState, + Vocabulary, +}; + +impl LoadError { + fn from_ggml_loader_error(value: ggml_loader::LoadError, path: PathBuf) -> Self { + match value { + ggml_loader::LoadError::InvalidMagic(_magic) => LoadError::InvalidMagic { path }, + ggml_loader::LoadError::InvalidFormatVersion(version) => { + LoadError::InvalidFormatVersion { version } + } + ggml_loader::LoadError::Io(err) => LoadError::Io(err), + ggml_loader::LoadError::FailedCast(err) => LoadError::InvalidIntegerConversion(err), + ggml_loader::LoadError::UserInterrupted(err) => err, + ggml_loader::LoadError::UnsupportedElementType(ty) => { + LoadError::HyperparametersF16Invalid { + ftype: ty.try_into().unwrap(), + } + } + ggml_loader::LoadError::InvariantBroken(invariant) => { + LoadError::InvariantBroken { path, invariant } + } + } + } +} + +pub(crate) fn load( + path: impl AsRef, + n_context_tokens: usize, + load_progress_callback: impl FnMut(LoadProgress), +) -> Result { + let main_path = path.as_ref(); + + let file = File::open(main_path).map_err(|e| LoadError::OpenFileFailed { + source: e, + path: main_path.to_owned(), + })?; + let mut reader = BufReader::new(&file); + + let path = path.as_ref().to_owned(); + let mut loader = Loader { + path: path.clone(), + state: LoadState::Vocabulary(Vocabulary::default()), + hyperparameters: Hyperparameters::default(), + container_type: ContainerType::GGJT, + 
load_progress_callback, + n_ctx: n_context_tokens, + }; + + ggml_loader::load_model_from_reader(&mut reader, &mut loader) + .map_err(|err| LoadError::from_ggml_loader_error(err, path.clone()))?; + + match loader.state { + LoadState::Vocabulary(_) => Err(LoadError::UnexpectedState { + path, + state: UnexpectedState::Vocabulary, + context: "Encountered vocabulary state while finalizing model".to_string(), + }), + LoadState::Model(model) => Ok(model), + } +} + +enum LoadState { + Vocabulary(Vocabulary), + Model(Model), +} +struct Loader { + // Context + path: PathBuf, + n_ctx: usize, + load_progress_callback: F, + + // Internal state + hyperparameters: Hyperparameters, + container_type: ContainerType, + + state: LoadState, +} +impl ggml_loader::LoadHandler> for Loader { + fn load_hyper_parameters( + &mut self, + reader: &mut BufReader<&File>, + ) -> ControlFlow { + let (hyperparameters, partial) = match load_hyperparameters(reader, self.n_ctx) { + Ok(t) => t, + Err(err) => { + return ControlFlow::Break(LoadError::from_ggml_loader_error( + err, + self.path.clone(), + )) + } + }; + self.hyperparameters = hyperparameters; + (self.load_progress_callback)(LoadProgress::HyperparametersLoaded(&self.hyperparameters)); + + ControlFlow::Continue(partial) + } + + fn got_container_type(&mut self, model_type: ContainerType) -> ControlFlow { + self.container_type = model_type; + ControlFlow::Continue(()) + } + + fn got_vocab_token(&mut self, i: usize, token: Vec, score: f32) -> ControlFlow { + let vocab = match &mut self.state { + LoadState::Vocabulary(v) => v, + LoadState::Model(_) => { + return ControlFlow::Break(LoadError::UnexpectedState { + path: self.path.clone(), + state: UnexpectedState::Model, + context: "Encountered model state while loading vocabulary".to_string(), + }) + } + }; + vocab.max_token_length = vocab.max_token_length.max(token.len()); + vocab.id_to_token.push(token.clone()); + let id = match TokenId::try_from(i) { + Ok(id) => id, + Err(err) => return ControlFlow::Break(LoadError::InvalidIntegerConversion(err)), + }; + vocab.token_to_id.insert(token, id); + vocab.id_to_token_score.push(score); + + ControlFlow::Continue(()) + } + + fn load_multipart(&mut self, _reader: &mut BufReader<&File>) -> ControlFlow { + // TODO: implement multipart loading + + (self.load_progress_callback)(LoadProgress::PartLoading { + file: &self.path, + current_part: 0, + total_parts: 1, + }); + + let vocabulary = match &self.state { + LoadState::Vocabulary(v) => v.clone(), + LoadState::Model(_) => { + return ControlFlow::Break(LoadError::UnexpectedState { + path: self.path.clone(), + state: UnexpectedState::Model, + context: "Encountered model state while transitioning into model state" + .to_string(), + }) + } + }; + let alloc = !(cfg!(feature = "mmap") && self.container_type == ContainerType::GGJT); + + let Hyperparameters { + n_vocab, + n_embd, + n_mult, + n_layer, + element_type, + .. + } = self.hyperparameters; + + let n_ff = ((2 * (4 * n_embd) / 3 + n_mult - 1) / n_mult) * n_mult; + let wtype = element_type; + + let ctx_size = { + // Use 64-bit math to prevent overflow. 
+ let mut ctx_size: usize = (5 + 10 * n_layer) * 256; // object overhead + + if alloc { + let mut model_size: usize = 0; + + ctx_size += mulf!(n_embd, n_vocab, ggml::type_sizef(wtype)); // tok_embeddings + ctx_size += mulf!(n_embd, ggml::type_sizef(ggml::Type::F32)); // norm + ctx_size += mulf!(n_embd, n_vocab, ggml::type_sizef(wtype)); // output + + model_size += mulf!(n_layer, n_embd, ggml::type_sizef(ggml::Type::F32)); // attention_norm + + model_size += mulf!(n_layer, n_embd, n_embd, ggml::type_sizef(wtype)); // wq + model_size += mulf!(n_layer, n_embd, n_embd, ggml::type_sizef(wtype)); // wk + model_size += mulf!(n_layer, n_embd, n_embd, ggml::type_sizef(wtype)); // wv + model_size += mulf!(n_layer, n_embd, n_embd, ggml::type_sizef(wtype)); // wo + + model_size += mulf!(n_layer, n_embd, ggml::type_sizef(ggml::Type::F32)); // ffn_norm + + model_size += mulf!(n_layer, n_ff, n_embd, ggml::type_sizef(wtype)); // w1 + model_size += mulf!(n_layer, n_ff, n_embd, ggml::type_sizef(wtype)); // w2 + model_size += mulf!(n_layer, n_ff, n_embd, ggml::type_sizef(wtype)); // w3 + + ctx_size += model_size; + } + + (self.load_progress_callback)(LoadProgress::ContextSize { bytes: ctx_size }); + + ctx_size + }; + + // Initialize the context + let context = ggml::Context::init(ctx_size, alloc); + + self.state = LoadState::Model(Model::new( + context, + self.hyperparameters, + vocabulary, + n_ff, + wtype, + self.container_type, + )); + ControlFlow::Continue(()) + } + + fn tensor_buffer(&mut self, info: TensorInfo) -> ControlFlow> { + let model = match &mut self.state { + LoadState::Model(m) => m, + LoadState::Vocabulary(_) => { + return ControlFlow::Break(LoadError::UnexpectedState { + path: self.path.clone(), + state: UnexpectedState::Vocabulary, + context: "Encountered vocabulary state while populating tensors".to_string(), + }) + } + }; + + let tensor_name = match String::from_utf8(info.name) { + Ok(n) => n, + Err(err) => return ControlFlow::Break(LoadError::InvalidUtf8(err)), + }; + + let tensor = match model.tensors_mut().get_mut(&tensor_name) { + Some(tensor) => tensor, + None => { + return ControlFlow::Break(LoadError::UnknownTensor { + path: self.path.clone(), + tensor_name, + }) + } + }; + + let buf: &mut [u8] = + unsafe { std::slice::from_raw_parts_mut(tensor.data() as *mut u8, tensor.nbytes()) }; + + (self.load_progress_callback)(LoadProgress::PartTensorLoaded { + file: &self.path, + // TODO: keep track of tensors loaded + current_tensor: 0, + tensor_count: model.tensors_mut().len(), + }); + + ControlFlow::Continue(Some(buf)) + } } /// use this to load params for llama model inside [`LoadHandler::load_hyper_parameters`] -#[allow(dead_code)] -pub fn load_llama_hparams( +fn load_hyperparameters( reader: &mut R, -) -> Result<(LlamaHyperparameters, PartialHyperparameters), LoadError> { + n_ctx: usize, +) -> Result<(Hyperparameters, PartialHyperparameters), ggml_loader::LoadError> { // NOTE: Field order matters! Data is laid out in the file exactly in this order. 
- let hparams = LlamaHyperparameters { + let hparams = Hyperparameters { n_vocab: read_i32(reader)?.try_into()?, n_embd: read_i32(reader)?.try_into()?, n_mult: read_i32(reader)?.try_into()?, n_head: read_i32(reader)?.try_into()?, n_layer: read_i32(reader)?.try_into()?, n_rot: read_i32(reader)?.try_into()?, - tensor_element_type: decode_element_type_res(read_i32(reader)?)?, + element_type: decode_element_type_res(read_i32(reader)?)?, + n_ctx, }; let partial = PartialHyperparameters { n_vocab: hparams.n_vocab, diff --git a/llama-rs/src/loader_common.rs b/llama-rs/src/loader_common.rs new file mode 100644 index 00000000..bea806de --- /dev/null +++ b/llama-rs/src/loader_common.rs @@ -0,0 +1,168 @@ +use std::path::{Path, PathBuf}; + +use thiserror::Error; + +use crate::{util::FindAllModelFilesError, Hyperparameters}; + +/// Each variant represents a step within the process of loading the model. +/// These can be used to report progress to the user. +#[derive(Clone, PartialEq, Eq, Debug)] +pub enum LoadProgress<'a> { + /// The hyperparameters have been loaded from the model. + HyperparametersLoaded(&'a Hyperparameters), + /// The context has been created. + ContextSize { + /// The size of the context. + bytes: usize, + }, + /// A part of the model is being loaded. + PartLoading { + /// The path to the model part. + file: &'a Path, + /// The current part (0-indexed). + current_part: usize, + /// The number of total parts. + total_parts: usize, + }, + /// A tensor from the current part has been loaded. + PartTensorLoaded { + /// The path to the model part. + file: &'a Path, + /// The current tensor (0-indexed). + current_tensor: usize, + /// The number of total tensors. + tensor_count: usize, + }, + /// A model part has finished fully loading. + PartLoaded { + /// The path to the model part. + file: &'a Path, + /// The number of bytes in the part. + byte_size: usize, + /// The number of tensors in the part. + tensor_count: usize, + }, +} + +#[derive(Error, Debug)] +/// Errors encountered during the loading process. +pub enum LoadError { + #[error("could not open file {path:?}")] + /// A file failed to open. + OpenFileFailed { + /// The original error. + source: std::io::Error, + /// The path that failed. + path: PathBuf, + }, + #[error("no parent path for {path:?}")] + /// There is no parent path for a given path. + NoParentPath { + /// The path without a parent. + path: PathBuf, + }, + #[error("unable to read exactly {bytes} bytes")] + /// Reading exactly `bytes` from a file failed. + ReadExactFailed { + /// The original error. + source: std::io::Error, + /// The number of bytes that were attempted to be read. + bytes: usize, + }, + #[error("non-specific I/O error")] + /// A non-specific IO error. + Io(#[from] std::io::Error), + #[error("could not convert bytes to a UTF-8 string")] + /// One of the strings encountered was not valid UTF-8. + InvalidUtf8(#[from] std::string::FromUtf8Error), + #[error("invalid integer conversion")] + /// One of the integers encountered could not be converted to a more appropriate type. + InvalidIntegerConversion(#[from] std::num::TryFromIntError), + #[error("unsupported f16_: {0}")] + /// One of the integers encountered could not be converted to a more appropriate type. + UnsupportedElementType(i32), + #[error("invalid magic number for {path:?}")] + /// An invalid magic number was encountered during the loading process. + InvalidMagic { + /// The path that failed. 
+ path: PathBuf, + }, + #[error("invalid file format version {version}")] + /// The version of the format is not supported by this version of `llama-rs`. + InvalidFormatVersion { + /// The version that was encountered. + version: u32, + }, + #[error("invalid value {ftype} for `f16` in hyperparameters")] + /// The `f16` hyperparameter had an invalid value. + HyperparametersF16Invalid { + /// The format type that was encountered. + ftype: u32, + }, + #[error("unknown tensor `{tensor_name}` in {path:?}")] + /// The tensor `tensor_name` was encountered during the loading of `path`, but was not seen during + /// the model prelude. + UnknownTensor { + /// The name of the tensor. + tensor_name: String, + /// The path that failed. + path: PathBuf, + }, + #[error("the tensor `{tensor_name}` has the wrong size in {path:?}")] + /// The tensor `tensor_name` did not match its expected size. + TensorWrongSize { + /// The name of the tensor. + tensor_name: String, + /// The path that failed. + path: PathBuf, + }, + /// The tensor `tensor_name` did not have the expected format type. + #[error("invalid ftype {ftype} for tensor `{tensor_name}` in {path:?}")] + InvalidFtype { + /// The name of the tensor. + tensor_name: String, + /// The format type that was encountered. + ftype: i32, + /// The path that failed. + path: PathBuf, + }, + /// An invariant was broken. + /// + /// This error is not relevant unless `loader2` is being used. + #[error("invariant broken: {invariant} in {path:?}")] + InvariantBroken { + /// The path that failed. + path: PathBuf, + /// The invariant that was broken. + invariant: String, + }, + /// The loader was in an unexpected state. + /// + /// This error is not relevant unless `loader2` is being used. + #[error("unexpected state {state:?} in {path:?}: {context}")] + UnexpectedState { + /// The path that failed. + path: PathBuf, + /// The state that was encountered. + state: UnexpectedState, + /// Context about what was expected. + context: String, + }, +} +impl From for LoadError { + fn from(value: FindAllModelFilesError) -> Self { + match value { + FindAllModelFilesError::NoParentPath { path } => LoadError::NoParentPath { path }, + FindAllModelFilesError::IO(err) => LoadError::Io(err), + } + } +} + +#[derive(Debug)] +/// The state that the loader was in when an error was encountered. +pub enum UnexpectedState { + /// The loader was in the `Vocabulary` state. + Vocabulary, + /// The loader was in the `Model` state. + Model, +} diff --git a/llama-rs/src/model.rs b/llama-rs/src/model.rs index b4cfa017..4ca4db1a 100644 --- a/llama-rs/src/model.rs +++ b/llama-rs/src/model.rs @@ -1,8 +1,8 @@ use std::{collections::HashMap, path::Path}; use crate::{ - loader, vocabulary::TokenId, EvaluateOutputRequest, InferenceParameters, InferenceSession, - InferenceSessionParameters, LoadError, LoadProgress, Mmap, Vocabulary, + loader, loader2, vocabulary::TokenId, EvaluateOutputRequest, InferenceParameters, + InferenceSession, InferenceSessionParameters, LoadError, LoadProgress, Mmap, Vocabulary, }; use ggml_loader::ContainerType; @@ -120,7 +120,13 @@ impl Model { n_context_tokens: usize, load_progress_callback: impl FnMut(LoadProgress), ) -> Result { - loader::load(path, n_context_tokens, load_progress_callback) + const USE_LOADER_2: bool = false; + + if USE_LOADER_2 { + loader2::load(path, n_context_tokens, load_progress_callback) + } else { + loader::load(path, n_context_tokens, load_progress_callback) + } } /// Starts a new `InferenceSession` for this model. 
@@ -438,7 +444,7 @@ impl Model {
 }
 
 /// The hyperparameters of the model.
-#[derive(Debug, Default, PartialEq, Eq)]
+#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)]
 pub struct Hyperparameters {
     /// n_vocab
     pub n_vocab: usize,
diff --git a/llama-rs/src/util.rs b/llama-rs/src/util.rs
index 3044bdb6..4ada4d22 100644
--- a/llama-rs/src/util.rs
+++ b/llama-rs/src/util.rs
@@ -1,7 +1,5 @@
 use std::path::{Path, PathBuf};
 
-use crate::LoadError;
-
 /// NOTE: The original code relies in promotion rules and automatic cast between
 /// int to float. What we do instead is use this macro to convert every term of
 /// the multiplication to f64, which should have enough precision bits to hold
diff --git a/llama-rs/src/vocabulary.rs b/llama-rs/src/vocabulary.rs
index 80e619c7..c87b49fb 100644
--- a/llama-rs/src/vocabulary.rs
+++ b/llama-rs/src/vocabulary.rs
@@ -8,7 +8,7 @@ pub(crate) type Token = Vec<u8>;
 pub(crate) type TokenScore = f32;
 
 /// The vocabulary used by a model.
-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, Default)]
 pub struct Vocabulary {
     /// Maps every integer (index) token id to its corresponding token
     pub(crate) id_to_token: Vec<Token>,

From 94951c463926edc16c7c440efc69dd7b2d0f1171 Mon Sep 17 00:00:00 2001
From: Philpax
Date: Thu, 20 Apr 2023 00:50:32 +0200
Subject: [PATCH 26/42] fix dead doc link

---
 llama-rs/src/inference_session.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llama-rs/src/inference_session.rs b/llama-rs/src/inference_session.rs
index 90179df2..428b9a7b 100644
--- a/llama-rs/src/inference_session.rs
+++ b/llama-rs/src/inference_session.rs
@@ -473,7 +473,7 @@ impl InferenceSnapshotRef<'_> {
 }
 
 /// A serializable snapshot of the inference process. Can be restored by calling
-/// [Model::session_from_snapshot].
+/// [InferenceSession::from_snapshot].
 #[derive(serde::Deserialize, Clone, PartialEq)]
 // Keep in sync with [InferenceSession] and [InferenceSnapshotRef].
 pub struct InferenceSnapshot {

From 69f355b999c11cf670348e65eb1d3a98aa52ff0f Mon Sep 17 00:00:00 2001
From: Philpax
Date: Thu, 20 Apr 2023 01:16:26 +0200
Subject: [PATCH 27/42] feat: turn mmap on by default, add --no-mmap

---
 llama-cli/Cargo.toml      |  3 --
 llama-cli/src/cli_args.rs | 96 +++++++++++++++++++++------------------
 llama-rs/Cargo.toml       |  7 +--
 llama-rs/src/lib.rs       | 19 --------
 llama-rs/src/loader.rs    | 37 +++++++++------
 llama-rs/src/loader2.rs   | 14 ++++--
 llama-rs/src/model.rs     | 14 ++++--
 7 files changed, 96 insertions(+), 94 deletions(-)

diff --git a/llama-cli/Cargo.toml b/llama-cli/Cargo.toml
index 72727e0f..b6de277f 100644
--- a/llama-cli/Cargo.toml
+++ b/llama-cli/Cargo.toml
@@ -19,6 +19,3 @@ once_cell = "1.17.1"
 rustyline = "11.0.0"
 spinners = "4.1.0"
 zstd = { version = "0.12", default-features = false }
-
-[features]
-mmap = ["llama-rs/mmap"]
\ No newline at end of file
diff --git a/llama-cli/src/cli_args.rs b/llama-cli/src/cli_args.rs
index 608b8647..7c9fe179 100644
--- a/llama-cli/src/cli_args.rs
+++ b/llama-cli/src/cli_args.rs
@@ -258,57 +258,67 @@ pub struct ModelLoad {
     /// will likely not perform as well as a model with a larger context size.
     #[arg(long, default_value_t = 2048)]
     pub num_ctx_tokens: usize,
+
+    /// Don't use mmap to load the model.
+ #[arg(long)] + pub no_mmap: bool, } impl ModelLoad { pub fn load(&self) -> llama_rs::Model { let now = std::time::Instant::now(); - let model = llama_rs::Model::load(&self.model_path, self.num_ctx_tokens, |progress| { - use llama_rs::LoadProgress; - match progress { - LoadProgress::HyperparametersLoaded(hparams) => { - log::debug!("Loaded hyperparameters {hparams:#?}") - } - LoadProgress::ContextSize { bytes } => log::info!( - "ggml ctx size = {:.2} MB\n", - bytes as f64 / (1024.0 * 1024.0) - ), - LoadProgress::PartLoading { - file, - current_part, - total_parts, - } => { - let current_part = current_part + 1; - log::info!( - "Loading model part {}/{} from '{}'\n", + let model = llama_rs::Model::load( + &self.model_path, + !self.no_mmap, + self.num_ctx_tokens, + |progress| { + use llama_rs::LoadProgress; + match progress { + LoadProgress::HyperparametersLoaded(hparams) => { + log::debug!("Loaded hyperparameters {hparams:#?}") + } + LoadProgress::ContextSize { bytes } => log::info!( + "ggml ctx size = {:.2} MB\n", + bytes as f64 / (1024.0 * 1024.0) + ), + LoadProgress::PartLoading { + file, current_part, total_parts, - file.to_string_lossy(), - ) - } - LoadProgress::PartTensorLoaded { - current_tensor, - tensor_count, - .. - } => { - let current_tensor = current_tensor + 1; - if current_tensor % 8 == 0 { - log::info!("Loaded tensor {current_tensor}/{tensor_count}"); + } => { + let current_part = current_part + 1; + log::info!( + "Loading model part {}/{} from '{}' (mmap: {})\n", + current_part, + total_parts, + file.to_string_lossy(), + !self.no_mmap + ) + } + LoadProgress::PartTensorLoaded { + current_tensor, + tensor_count, + .. + } => { + let current_tensor = current_tensor + 1; + if current_tensor % 8 == 0 { + log::info!("Loaded tensor {current_tensor}/{tensor_count}"); + } + } + LoadProgress::PartLoaded { + file, + byte_size, + tensor_count, + } => { + log::info!("Loading of '{}' complete", file.to_string_lossy()); + log::info!( + "Model size = {:.2} MB / num tensors = {}", + byte_size as f64 / 1024.0 / 1024.0, + tensor_count + ); } } - LoadProgress::PartLoaded { - file, - byte_size, - tensor_count, - } => { - log::info!("Loading of '{}' complete", file.to_string_lossy()); - log::info!( - "Model size = {:.2} MB / num tensors = {}", - byte_size as f64 / 1024.0 / 1024.0, - tensor_count - ); - } - } - }) + }, + ) .expect("Could not load model"); log::info!( diff --git a/llama-rs/Cargo.toml b/llama-rs/Cargo.toml index 4e854a43..c746e388 100644 --- a/llama-rs/Cargo.toml +++ b/llama-rs/Cargo.toml @@ -17,7 +17,7 @@ partial_sort = "0.2.0" thiserror = "1.0" serde = { version = "1.0", features = ["derive"] } serde_bytes = "0.11" -memmap2 = { version = "0.5.10", optional = true } +memmap2 = "0.5.10" # Used for the `convert` feature serde_json = { version = "1.0", optional = true } @@ -25,7 +25,4 @@ protobuf = { version = "= 2.14.0", optional = true } rust_tokenizers = { version = "3.1.2", optional = true } [features] -convert = ["dep:serde_json", "dep:protobuf", "dep:rust_tokenizers"] - -# broken atm, see https://github.com/rustformers/llama-rs/pull/114#issuecomment-1500337463 -mmap = ["dep:memmap2"] +convert = ["dep:serde_json", "dep:protobuf", "dep:rust_tokenizers"] \ No newline at end of file diff --git a/llama-rs/src/lib.rs b/llama-rs/src/lib.rs index dc91f864..ff35e29e 100644 --- a/llama-rs/src/lib.rs +++ b/llama-rs/src/lib.rs @@ -24,25 +24,6 @@ pub use model::{Hyperparameters, Model}; pub use util::TokenUtf8Buffer; pub use vocabulary::{TokenBias, TokenId, Vocabulary}; -#[cfg(feature = 
"mmap")] -use memmap2::Mmap; - -/// dummy struct -#[cfg(not(feature = "mmap"))] -pub(crate) struct Mmap; - -/// dummy impl -#[cfg(not(feature = "mmap"))] -impl Mmap { - pub(crate) unsafe fn map(_: &std::fs::File) -> std::io::Result { - Ok(Mmap) - } - pub(crate) fn as_ptr(&self) -> *const u8 { - std::ptr::null() - } -} -// map - /// The end of text token. pub const EOT_TOKEN_ID: TokenId = 2; // Hardcoded (for now?) diff --git a/llama-rs/src/loader.rs b/llama-rs/src/loader.rs index 2b2ad656..b3441be5 100644 --- a/llama-rs/src/loader.rs +++ b/llama-rs/src/loader.rs @@ -8,14 +8,16 @@ use std::{ use crate::{ util::{self, mulf}, - LoadError, LoadProgress, Mmap, Model, TokenId, Vocabulary, + LoadError, LoadProgress, Model, TokenId, Vocabulary, }; use crate::{ElementType, Hyperparameters}; use ggml_loader::util::*; use ggml_loader::ContainerType; +use memmap2::Mmap; pub(crate) fn load( path: impl AsRef, + use_mmap: bool, n_context_tokens: usize, mut load_progress_callback: impl FnMut(LoadProgress), ) -> Result { @@ -124,7 +126,7 @@ pub(crate) fn load( let n_layer = hparams.n_layer; let n_vocab = hparams.n_vocab; - let alloc = !(cfg!(feature = "mmap") && model_type == ContainerType::GGJT); + let alloc = !(use_mmap && model_type == ContainerType::GGJT); let ctx_size = { // Use 64-bit math to prevent overflow. @@ -161,7 +163,15 @@ pub(crate) fn load( // Initialize the context let context = ggml::Context::init(ctx_size, alloc); - let mut model = Model::new(context, hparams, vocabulary, n_ff, wtype, model_type); + let (mmap, mmap_ptr) = if use_mmap && model_type == ContainerType::GGJT { + let mmap = unsafe { Mmap::map(&file)? }; + let ptr = mmap.as_ptr(); + (Some(mmap), Some(ptr)) + } else { + (None, None) + }; + + let mut model = Model::new(context, hparams, vocabulary, n_ff, wtype, model_type, mmap); match model_type { ContainerType::GGMF | ContainerType::GGML => { let file_offset = reader.stream_position()?; @@ -174,12 +184,9 @@ pub(crate) fn load( )? } ContainerType::GGJT => { - let mmap = unsafe { Mmap::map(&file)? 
}; - let ptr = mmap.as_ptr(); - model.mmap = Some(mmap); load_weights_ggjt( &mut reader, - ptr, + mmap_ptr, main_path, load_progress_callback, model.tensors_mut(), @@ -438,7 +445,7 @@ fn tensor_type_size(ftype: i32, ne: [i64; 2]) -> Option { fn load_weights_ggjt( reader: &mut (impl BufRead + Seek), - mmap_base: *const u8, + mmap_base: Option<*const u8>, path: &Path, mut load_progress_callback: impl FnMut(LoadProgress), tensors: &mut HashMap, @@ -502,7 +509,11 @@ fn load_weights_ggjt( } }; - load_tensor_ggjt(reader, mmap_base, tensor)?; + if let Some(mmap_base) = mmap_base { + load_tensor_ggjt_mmap(reader, mmap_base, tensor)?; + } else { + load_tensor_ggjt_copy(reader, tensor)?; + } total_loaded_bytes += tensor.nbytes() as u64; @@ -524,8 +535,7 @@ fn load_weights_ggjt( Ok(()) } -#[cfg(feature = "mmap")] -fn load_tensor_ggjt( +fn load_tensor_ggjt_mmap( reader: &mut (impl BufRead + Seek), mmap_base: *const u8, tensor: &mut ggml::Tensor, @@ -540,13 +550,10 @@ fn load_tensor_ggjt( Ok(()) } -#[cfg(not(feature = "mmap"))] -fn load_tensor_ggjt<'a>( +fn load_tensor_ggjt_copy<'a>( reader: &mut (impl BufRead + Seek), - mmap_base: *const u8, tensor: &'a mut ggml::Tensor, ) -> Result<(), LoadError> { - _ = mmap_base; let offset_curr = reader.stream_position()?; let offset_aligned: u64 = (offset_curr + 31) & !31; reader.seek(SeekFrom::Start(offset_aligned))?; diff --git a/llama-rs/src/loader2.rs b/llama-rs/src/loader2.rs index 6440589c..69aac083 100644 --- a/llama-rs/src/loader2.rs +++ b/llama-rs/src/loader2.rs @@ -46,6 +46,7 @@ impl LoadError { pub(crate) fn load( path: impl AsRef, + use_mmap: bool, n_context_tokens: usize, load_progress_callback: impl FnMut(LoadProgress), ) -> Result { @@ -60,11 +61,14 @@ pub(crate) fn load( let path = path.as_ref().to_owned(); let mut loader = Loader { path: path.clone(), - state: LoadState::Vocabulary(Vocabulary::default()), + n_ctx: n_context_tokens, + load_progress_callback, + use_mmap, + hyperparameters: Hyperparameters::default(), container_type: ContainerType::GGJT, - load_progress_callback, - n_ctx: n_context_tokens, + + state: LoadState::Vocabulary(Vocabulary::default()), }; ggml_loader::load_model_from_reader(&mut reader, &mut loader) @@ -89,6 +93,7 @@ struct Loader { path: PathBuf, n_ctx: usize, load_progress_callback: F, + use_mmap: bool, // Internal state hyperparameters: Hyperparameters, @@ -164,7 +169,7 @@ impl ggml_loader::LoadHandler ggml_loader::LoadHandler, - pub(crate) mmap: Option, + /// Needs to kept alive while the model is alive + _mmap: Option, _version: ContainerType, @@ -38,6 +40,7 @@ impl Model { n_ff: usize, wtype: ggml::Type, model_type: ContainerType, + mmap: Option, ) -> Model { let n_embd = hparams.n_embd; let n_layer = hparams.n_layer; @@ -107,7 +110,7 @@ impl Model { layers, tensors, _context: context, - mmap: None, + _mmap: mmap, _version: model_type, } } @@ -117,15 +120,16 @@ impl Model { /// The status of the loading process will be reported through `load_progress_callback`. 
pub fn load( path: impl AsRef, + use_mmap: bool, n_context_tokens: usize, load_progress_callback: impl FnMut(LoadProgress), ) -> Result { const USE_LOADER_2: bool = false; if USE_LOADER_2 { - loader2::load(path, n_context_tokens, load_progress_callback) + loader2::load(path, use_mmap, n_context_tokens, load_progress_callback) } else { - loader::load(path, n_context_tokens, load_progress_callback) + loader::load(path, use_mmap, n_context_tokens, load_progress_callback) } } From 17bc0cc8cfb5d2fae4d3a3f92035a7a20acc0ca5 Mon Sep 17 00:00:00 2001 From: Locria Cyber <74560659+iacore@users.noreply.github.com> Date: Thu, 20 Apr 2023 16:54:26 +0000 Subject: [PATCH 28/42] Fix loading GGJT --- Cargo.lock | 1 + ggml-loader/src/lib.rs | 7 -- ggml-loader/src/util.rs | 15 ++++ llama-rs/Cargo.toml | 3 +- llama-rs/src/lib.rs | 2 +- llama-rs/src/loader2.rs | 156 ++++++++++++++----------------------- llama-rs/src/model.rs | 4 +- llama-rs/src/vocabulary.rs | 38 +++++++++ 8 files changed, 117 insertions(+), 109 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 6bb4e48e..2fa33e94 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -549,6 +549,7 @@ dependencies = [ "bytemuck", "ggml", "ggml-loader", + "log", "memmap2", "partial_sort", "protobuf", diff --git a/ggml-loader/src/lib.rs b/ggml-loader/src/lib.rs index b490868a..761977ca 100644 --- a/ggml-loader/src/lib.rs +++ b/ggml-loader/src/lib.rs @@ -101,13 +101,6 @@ fn can_be_vtable() { let _a: MaybeUninit>> = MaybeUninit::uninit(); } -fn retchk(model_type: ControlFlow) -> Result> { - match model_type { - ControlFlow::Continue(x) => Ok(x), - ControlFlow::Break(y) => Err(LoadError::UserInterrupted(y)), - } -} - pub fn load_model_from_reader( reader: &mut R, handler: &mut impl LoadHandler, diff --git a/ggml-loader/src/util.rs b/ggml-loader/src/util.rs index 342f3bd0..61852fbf 100644 --- a/ggml-loader/src/util.rs +++ b/ggml-loader/src/util.rs @@ -1,4 +1,5 @@ pub use std::io::{BufRead, Seek, SeekFrom}; +use std::ops::ControlFlow; use crate::{ElementType, LoadError}; @@ -60,3 +61,17 @@ pub fn decode_element_type_res(ftype: i32) -> Result Err(LoadError::UnsupportedElementType(ftype)), } } + +pub fn retchk(model_type: ControlFlow) -> Result> { + match model_type { + ControlFlow::Continue(x) => Ok(x), + ControlFlow::Break(y) => Err(LoadError::UserInterrupted(y)), + } +} + +pub fn brkchk(model_type: Result) -> ControlFlow { + match model_type { + Ok(x) => ControlFlow::Continue(x), + Err(y) => ControlFlow::Break(y), + } +} diff --git a/llama-rs/Cargo.toml b/llama-rs/Cargo.toml index c746e388..6fecd50d 100644 --- a/llama-rs/Cargo.toml +++ b/llama-rs/Cargo.toml @@ -23,6 +23,7 @@ memmap2 = "0.5.10" serde_json = { version = "1.0", optional = true } protobuf = { version = "= 2.14.0", optional = true } rust_tokenizers = { version = "3.1.2", optional = true } +log = "*" [features] -convert = ["dep:serde_json", "dep:protobuf", "dep:rust_tokenizers"] \ No newline at end of file +convert = ["dep:serde_json", "dep:protobuf", "dep:rust_tokenizers"] diff --git a/llama-rs/src/lib.rs b/llama-rs/src/lib.rs index ff35e29e..3f0a6c69 100644 --- a/llama-rs/src/lib.rs +++ b/llama-rs/src/lib.rs @@ -19,7 +19,7 @@ pub use inference_session::{ InferenceSession, InferenceSessionParameters, InferenceSnapshot, ModelKVMemoryType, SnapshotError, }; -pub use loader_common::{LoadError, LoadProgress, UnexpectedState}; +pub use loader_common::{LoadError, LoadProgress}; pub use model::{Hyperparameters, Model}; pub use util::TokenUtf8Buffer; pub use vocabulary::{TokenBias, TokenId, Vocabulary}; diff --git 
a/llama-rs/src/loader2.rs b/llama-rs/src/loader2.rs index 69aac083..5c5ace90 100644 --- a/llama-rs/src/loader2.rs +++ b/llama-rs/src/loader2.rs @@ -17,10 +17,7 @@ use std::{ path::{Path, PathBuf}, }; -use crate::{ - util::mulf, Hyperparameters, LoadError, LoadProgress, Model, TokenId, UnexpectedState, - Vocabulary, -}; +use crate::{util::mulf, Hyperparameters, LoadError, LoadProgress, Model, TokenId, Vocabulary}; impl LoadError { fn from_ggml_loader_error(value: ggml_loader::LoadError, path: PathBuf) -> Self { @@ -61,46 +58,36 @@ pub(crate) fn load( let path = path.as_ref().to_owned(); let mut loader = Loader { path: path.clone(), + vocab: Default::default(), + model: None, n_ctx: n_context_tokens, load_progress_callback, use_mmap, hyperparameters: Hyperparameters::default(), container_type: ContainerType::GGJT, - - state: LoadState::Vocabulary(Vocabulary::default()), }; ggml_loader::load_model_from_reader(&mut reader, &mut loader) .map_err(|err| LoadError::from_ggml_loader_error(err, path.clone()))?; - match loader.state { - LoadState::Vocabulary(_) => Err(LoadError::UnexpectedState { - path, - state: UnexpectedState::Vocabulary, - context: "Encountered vocabulary state while finalizing model".to_string(), - }), - LoadState::Model(model) => Ok(model), - } + Ok(loader.model.expect("model should be initialized")) } -enum LoadState { - Vocabulary(Vocabulary), - Model(Model), -} struct Loader { - // Context + // input data and options path: PathBuf, n_ctx: usize, - load_progress_callback: F, use_mmap: bool, // Internal state - hyperparameters: Hyperparameters, container_type: ContainerType, - - state: LoadState, + hyperparameters: Hyperparameters, + model: Option, + vocab: Vocabulary, + load_progress_callback: F, } + impl ggml_loader::LoadHandler> for Loader { fn load_hyper_parameters( &mut self, @@ -127,50 +114,68 @@ impl ggml_loader::LoadHandler, score: f32) -> ControlFlow { - let vocab = match &mut self.state { - LoadState::Vocabulary(v) => v, - LoadState::Model(_) => { - return ControlFlow::Break(LoadError::UnexpectedState { - path: self.path.clone(), - state: UnexpectedState::Model, - context: "Encountered model state while loading vocabulary".to_string(), - }) - } - }; - vocab.max_token_length = vocab.max_token_length.max(token.len()); - vocab.id_to_token.push(token.clone()); let id = match TokenId::try_from(i) { Ok(id) => id, Err(err) => return ControlFlow::Break(LoadError::InvalidIntegerConversion(err)), }; - vocab.token_to_id.insert(token, id); - vocab.id_to_token_score.push(score); + self.vocab + .push_token(id, token, score) + .expect("vocab should be valid"); ControlFlow::Continue(()) } fn load_multipart(&mut self, _reader: &mut BufReader<&File>) -> ControlFlow { - // TODO: implement multipart loading + // todo + log::warn!("multipart model is not supported"); + ControlFlow::Continue(()) + } - (self.load_progress_callback)(LoadProgress::PartLoading { - file: &self.path, - current_part: 0, - total_parts: 1, - }); + fn tensor_buffer(&mut self, info: TensorInfo) -> ControlFlow> { + if self.model.is_none() { + self.model = Some(self.create_model(self.vocab.clone())); + } + let model = &mut self.model.as_mut().expect("initialized"); - let vocabulary = match &self.state { - LoadState::Vocabulary(v) => v.clone(), - LoadState::Model(_) => { - return ControlFlow::Break(LoadError::UnexpectedState { + let tensor_name = match String::from_utf8(info.name) { + Ok(n) => n, + Err(err) => return ControlFlow::Break(LoadError::InvalidUtf8(err)), + }; + + let tensor = match 
model.tensors_mut().get_mut(&tensor_name) { + Some(tensor) => tensor, + None => { + return ControlFlow::Break(LoadError::UnknownTensor { path: self.path.clone(), - state: UnexpectedState::Model, - context: "Encountered model state while transitioning into model state" - .to_string(), + tensor_name, }) } }; - let alloc = !(self.use_mmap && self.container_type == ContainerType::GGJT); + // todo: support mmap + let buf: &mut [u8] = + unsafe { std::slice::from_raw_parts_mut(tensor.data() as *mut u8, tensor.nbytes()) }; + + let tensor_count = model.tensors_mut().len(); + (self.load_progress_callback)(LoadProgress::PartTensorLoaded { + file: &self.path, + // TODO: keep track of tensors loaded + current_tensor: 0, + tensor_count, + }); + + ControlFlow::Continue(Some(buf)) + } +} + +impl Loader { + fn create_model(&mut self, vocabulary: Vocabulary) -> Model { + (self.load_progress_callback)(LoadProgress::PartLoading { + file: &self.path, + current_part: 0, + total_parts: 1, + }); + let alloc = !(self.use_mmap && self.container_type == ContainerType::GGJT); let Hyperparameters { n_vocab, n_embd, @@ -179,10 +184,8 @@ impl ggml_loader::LoadHandler ggml_loader::LoadHandler ggml_loader::LoadHandler ControlFlow> { - let model = match &mut self.state { - LoadState::Model(m) => m, - LoadState::Vocabulary(_) => { - return ControlFlow::Break(LoadError::UnexpectedState { - path: self.path.clone(), - state: UnexpectedState::Vocabulary, - context: "Encountered vocabulary state while populating tensors".to_string(), - }) - } - }; - - let tensor_name = match String::from_utf8(info.name) { - Ok(n) => n, - Err(err) => return ControlFlow::Break(LoadError::InvalidUtf8(err)), - }; - - let tensor = match model.tensors_mut().get_mut(&tensor_name) { - Some(tensor) => tensor, - None => { - return ControlFlow::Break(LoadError::UnknownTensor { - path: self.path.clone(), - tensor_name, - }) - } - }; - - let buf: &mut [u8] = - unsafe { std::slice::from_raw_parts_mut(tensor.data() as *mut u8, tensor.nbytes()) }; - - (self.load_progress_callback)(LoadProgress::PartTensorLoaded { - file: &self.path, - // TODO: keep track of tensors loaded - current_tensor: 0, - tensor_count: model.tensors_mut().len(), - }); - - ControlFlow::Continue(Some(buf)) + ) } } diff --git a/llama-rs/src/model.rs b/llama-rs/src/model.rs index d37e97b2..0b547d2b 100644 --- a/llama-rs/src/model.rs +++ b/llama-rs/src/model.rs @@ -124,9 +124,9 @@ impl Model { n_context_tokens: usize, load_progress_callback: impl FnMut(LoadProgress), ) -> Result { - const USE_LOADER_2: bool = false; + let use_loader_2: bool = std::env::var("USE_LOADER_2").is_ok(); - if USE_LOADER_2 { + if use_loader_2 { loader2::load(path, use_mmap, n_context_tokens, load_progress_callback) } else { loader::load(path, use_mmap, n_context_tokens, load_progress_callback) diff --git a/llama-rs/src/vocabulary.rs b/llama-rs/src/vocabulary.rs index c87b49fb..974d24c0 100644 --- a/llama-rs/src/vocabulary.rs +++ b/llama-rs/src/vocabulary.rs @@ -1,5 +1,7 @@ use std::{collections::HashMap, str::FromStr}; +use thiserror::Error; + use crate::InferenceError; /// The identifier of a token in a vocabulary. 
@@ -16,13 +18,49 @@ pub struct Vocabulary { /// Maps every integer (index) token id to corresponding score pub(crate) id_to_token_score: Vec, + // todo: use a radix tree /// Maps a token to a token id pub(crate) token_to_id: HashMap, /// The longest token in this vocabulary pub(crate) max_token_length: usize, } + +#[derive(Debug, Clone, Error)] +pub enum AddTokenError { + #[error("the id of token added should be {expected_id}; is {actual_id}")] + WrongId { + expected_id: TokenId, + actual_id: TokenId, + }, + #[error("a token with the same id already exists, id={id}")] + AlreadyAdded { id: TokenId }, +} + impl Vocabulary { + /// Add a token to the vocabulary. + /// + /// The token added must have `id` directly after the last token in the vocabulary. + pub fn push_token( + &mut self, + id: TokenId, + content: Token, + score: TokenScore, + ) -> Result<(), AddTokenError> { + assert_eq!(self.id_to_token.len(), self.id_to_token_score.len()); + if self.id_to_token.len() != id as usize || self.id_to_token_score.len() != id as usize { + return Err(AddTokenError::WrongId { + expected_id: self.id_to_token.len() as TokenId, + actual_id: id, + }); + } + self.max_token_length = self.max_token_length.max(content.len()); + self.id_to_token.push(content.clone()); + self.id_to_token_score.push(score); + self.token_to_id.insert(content, id); + Ok(()) + } + pub(crate) fn token(&self, idx: usize) -> &[u8] { &self.id_to_token[idx] } From 6641ae901d549172799bf74ffb68fd2edfd6e6ec Mon Sep 17 00:00:00 2001 From: Locria Cyber <74560659+iacore@users.noreply.github.com> Date: Thu, 20 Apr 2023 17:39:35 +0000 Subject: [PATCH 29/42] minor fix --- ggml-loader/src/lib.rs | 9 +-------- llama-rs/src/loader2.rs | 2 +- llama-rs/src/loader_common.rs | 2 +- 3 files changed, 3 insertions(+), 10 deletions(-) diff --git a/ggml-loader/src/lib.rs b/ggml-loader/src/lib.rs index 761977ca..8c41f678 100644 --- a/ggml-loader/src/lib.rs +++ b/ggml-loader/src/lib.rs @@ -149,17 +149,10 @@ pub fn load_model_from_reader( retchk(handler.load_multipart(reader))?; load_weights(reader, handler, false) } - ContainerType::GGJT => load_weights_ggjt(reader, handler), + ContainerType::GGJT => load_weights(reader, handler, true), } } -pub fn load_weights_ggjt( - reader: &mut R, - handler: &mut impl LoadHandler, -) -> Result<(), LoadError> { - load_weights(reader, handler, true) -} - /// # Params /// /// `align` diff --git a/llama-rs/src/loader2.rs b/llama-rs/src/loader2.rs index 5c5ace90..b7c9f365 100644 --- a/llama-rs/src/loader2.rs +++ b/llama-rs/src/loader2.rs @@ -31,7 +31,7 @@ impl LoadError { ggml_loader::LoadError::UserInterrupted(err) => err, ggml_loader::LoadError::UnsupportedElementType(ty) => { LoadError::HyperparametersF16Invalid { - ftype: ty.try_into().unwrap(), + ftype: ty, } } ggml_loader::LoadError::InvariantBroken(invariant) => { diff --git a/llama-rs/src/loader_common.rs b/llama-rs/src/loader_common.rs index bea806de..1c264c88 100644 --- a/llama-rs/src/loader_common.rs +++ b/llama-rs/src/loader_common.rs @@ -97,7 +97,7 @@ pub enum LoadError { /// The `f16` hyperparameter had an invalid value. HyperparametersF16Invalid { /// The format type that was encountered. 
- ftype: u32, + ftype: i32, }, #[error("unknown tensor `{tensor_name}` in {path:?}")] /// The tensor `tensor_name` was encountered during the loading of `path`, but was not seen during From 3910b6a57b1bf7b538825ae3552ebf5c0295f47c Mon Sep 17 00:00:00 2001 From: Locria Cyber <74560659+iacore@users.noreply.github.com> Date: Thu, 20 Apr 2023 18:02:18 +0000 Subject: [PATCH 30/42] Add mmap --- ggml-loader/src/lib.rs | 51 ++++++++++++-------- ggml-loader/src/util.rs | 10 ++-- llama-cli/src/cli_args.rs | 4 +- llama-rs/src/loader.rs | 6 +-- llama-rs/src/loader2.rs | 98 ++++++++++++++++++++++++++------------- llama-rs/src/model.rs | 14 +++--- 6 files changed, 115 insertions(+), 68 deletions(-) diff --git a/ggml-loader/src/lib.rs b/ggml-loader/src/lib.rs index 8c41f678..45b88058 100644 --- a/ggml-loader/src/lib.rs +++ b/ggml-loader/src/lib.rs @@ -22,6 +22,16 @@ pub enum ContainerType { GGJT, } +impl ContainerType { + pub fn support_mmap(&self) -> bool { + match self { + ContainerType::GGML => false, + ContainerType::GGMF => false, + ContainerType::GGJT => true, + } + } +} + #[derive(Debug, thiserror::Error)] pub enum LoadError { #[error("invalid file magic number: {0}")] @@ -66,9 +76,17 @@ pub struct PartialHyperparameters { pub n_vocab: usize, } +pub enum TensorDataTreatment<'a> { + CopyInto(&'a mut [u8]), + SeekPast { + /// should be `tensor.nbytes` + n_bytes: usize + }, +} + #[allow(unused_variables)] pub trait LoadHandler { - fn got_container_type(&mut self, model_type: ContainerType) -> ControlFlow { + fn got_container_type(&mut self, container_type: ContainerType) -> ControlFlow { ControlFlow::Continue(()) } @@ -90,9 +108,7 @@ pub trait LoadHandler { /// /// `None` to skip copying /// `Some(buf)` to provide a buffer for copying weights into - fn tensor_buffer(&mut self, info: TensorInfo) -> ControlFlow> { - ControlFlow::Continue(None) - } + fn tensor_buffer(&mut self, info: TensorInfo) -> ControlFlow; } #[test] @@ -164,6 +180,9 @@ pub fn load_weights( ) -> Result<(), LoadError> { while has_data_left(reader)? { // load tensor header + let start_pos = reader.stream_position()?; + dbg!(start_pos); + let n_dims: usize = read_i32(reader)?.try_into()?; let name_len = read_i32(reader)?; let ftype = decode_element_type_res(read_i32(reader)?)?; @@ -211,23 +230,17 @@ pub fn load_weights( start_offset: offset_aligned, }; - let type_size = ggml::type_size(ftype); - if let Some(buf) = retchk(handler.tensor_buffer(tensor_info))? { - if align { - reader.seek(SeekFrom::Start(offset_aligned))?; + match retchk(handler.tensor_buffer(tensor_info))? 
{ + TensorDataTreatment::CopyInto(buf) => { + if align { + reader.seek(SeekFrom::Start(offset_aligned))?; + } + reader.read_exact(buf)?; } - let buf_len = buf.len(); - if !(buf_len == type_size * n_elements) { - return Err(LoadError::InvariantBroken(format!( - "{buf_len} == {type_size} * {n_elements}" - ))); + TensorDataTreatment::SeekPast { n_bytes } => { + // skip if no buffer is given + reader.seek(SeekFrom::Start(offset_aligned + n_bytes as u64))?; } - reader.read_exact(buf)?; - } else { - // skip if no buffer is given - reader.seek(SeekFrom::Start( - offset_aligned + (type_size * n_elements) as u64, - ))?; } } diff --git a/ggml-loader/src/util.rs b/ggml-loader/src/util.rs index 61852fbf..92c2be40 100644 --- a/ggml-loader/src/util.rs +++ b/ggml-loader/src/util.rs @@ -62,16 +62,16 @@ pub fn decode_element_type_res(ftype: i32) -> Result(model_type: ControlFlow) -> Result> { - match model_type { +pub fn retchk(x: ControlFlow) -> Result> { + match x { ControlFlow::Continue(x) => Ok(x), ControlFlow::Break(y) => Err(LoadError::UserInterrupted(y)), } } -pub fn brkchk(model_type: Result) -> ControlFlow { - match model_type { +pub fn brkchk>(x: Result) -> ControlFlow { + match x { Ok(x) => ControlFlow::Continue(x), - Err(y) => ControlFlow::Break(y), + Err(y) => ControlFlow::Break(y.into()), } } diff --git a/llama-cli/src/cli_args.rs b/llama-cli/src/cli_args.rs index 7c9fe179..d31c4ec9 100644 --- a/llama-cli/src/cli_args.rs +++ b/llama-cli/src/cli_args.rs @@ -388,8 +388,8 @@ pub enum ElementType { F32, } impl From for llama_rs::ElementType { - fn from(model_type: ElementType) -> Self { - match model_type { + fn from(t: ElementType) -> Self { + match t { ElementType::Q4_0 => llama_rs::ElementType::Q4_0, ElementType::Q4_1 => llama_rs::ElementType::Q4_1, ElementType::F16 => llama_rs::ElementType::F16, diff --git a/llama-rs/src/loader.rs b/llama-rs/src/loader.rs index b3441be5..ae056605 100644 --- a/llama-rs/src/loader.rs +++ b/llama-rs/src/loader.rs @@ -17,7 +17,7 @@ use memmap2::Mmap; pub(crate) fn load( path: impl AsRef, - use_mmap: bool, + prefer_mmap: bool, n_context_tokens: usize, mut load_progress_callback: impl FnMut(LoadProgress), ) -> Result { @@ -126,7 +126,7 @@ pub(crate) fn load( let n_layer = hparams.n_layer; let n_vocab = hparams.n_vocab; - let alloc = !(use_mmap && model_type == ContainerType::GGJT); + let alloc = !(prefer_mmap && model_type.support_mmap()); let ctx_size = { // Use 64-bit math to prevent overflow. @@ -163,7 +163,7 @@ pub(crate) fn load( // Initialize the context let context = ggml::Context::init(ctx_size, alloc); - let (mmap, mmap_ptr) = if use_mmap && model_type == ContainerType::GGJT { + let (mmap, mmap_ptr) = if prefer_mmap && model_type.support_mmap() { let mmap = unsafe { Mmap::map(&file)? 
}; let ptr = mmap.as_ptr(); (Some(mmap), Some(ptr)) diff --git a/llama-rs/src/loader2.rs b/llama-rs/src/loader2.rs index b7c9f365..e3ae6130 100644 --- a/llama-rs/src/loader2.rs +++ b/llama-rs/src/loader2.rs @@ -9,6 +9,7 @@ use ggml_loader::util::*; use ggml_loader::*; +use memmap2::Mmap; use std::{ fs::File, @@ -43,7 +44,7 @@ impl LoadError { pub(crate) fn load( path: impl AsRef, - use_mmap: bool, + prefer_mmap: bool, n_context_tokens: usize, load_progress_callback: impl FnMut(LoadProgress), ) -> Result { @@ -62,8 +63,9 @@ pub(crate) fn load( model: None, n_ctx: n_context_tokens, load_progress_callback, - use_mmap, + preper_mmap: prefer_mmap, + tensor_accumulator: 0, hyperparameters: Hyperparameters::default(), container_type: ContainerType::GGJT, }; @@ -78,9 +80,10 @@ struct Loader { // input data and options path: PathBuf, n_ctx: usize, - use_mmap: bool, + preper_mmap: bool, // Internal state + tensor_accumulator: usize, container_type: ContainerType, hyperparameters: Hyperparameters, model: Option, @@ -108,8 +111,8 @@ impl ggml_loader::LoadHandler ControlFlow { - self.container_type = model_type; + fn got_container_type(&mut self, t: ContainerType) -> ControlFlow { + self.container_type = t; ControlFlow::Continue(()) } @@ -131,9 +134,9 @@ impl ggml_loader::LoadHandler ControlFlow> { + fn tensor_buffer(&mut self, info: TensorInfo) -> ControlFlow { if self.model.is_none() { - self.model = Some(self.create_model(self.vocab.clone())); + self.model = Some(brkchk(self.create_model(self.vocab.clone()))?); } let model = &mut self.model.as_mut().expect("initialized"); @@ -142,40 +145,59 @@ impl ggml_loader::LoadHandler return ControlFlow::Break(LoadError::InvalidUtf8(err)), }; - let tensor = match model.tensors_mut().get_mut(&tensor_name) { - Some(tensor) => tensor, + let tensor_count = model.tensors_mut().len(); + + // to satisfy borrow checker + macro_rules! 
get_tensor { + () => { + match model.tensors_mut().get_mut(&tensor_name) { + Some(tensor) => tensor, + None => { + return ControlFlow::Break(LoadError::UnknownTensor { + path: self.path.clone(), + tensor_name, + }) + } + } + }; + } + + // todo: support mmap + let ret = match &model.mmap { + Some(map) => unsafe { + let ptr = map.as_ptr().offset(info.start_offset as isize); + let tensor = get_tensor!(); + tensor.set_data(ptr as *mut std::ffi::c_void); + TensorDataTreatment::SeekPast { n_bytes: tensor.nbytes() } + }, None => { - return ControlFlow::Break(LoadError::UnknownTensor { - path: self.path.clone(), - tensor_name, - }) + let tensor = get_tensor!(); + let buf: &mut [u8] = unsafe { + std::slice::from_raw_parts_mut(tensor.data() as *mut u8, tensor.nbytes()) + }; + TensorDataTreatment::CopyInto(buf) } }; - - // todo: support mmap - let buf: &mut [u8] = - unsafe { std::slice::from_raw_parts_mut(tensor.data() as *mut u8, tensor.nbytes()) }; - - let tensor_count = model.tensors_mut().len(); (self.load_progress_callback)(LoadProgress::PartTensorLoaded { file: &self.path, // TODO: keep track of tensors loaded - current_tensor: 0, + current_tensor: self.tensor_accumulator, tensor_count, }); + self.tensor_accumulator += 1; - ControlFlow::Continue(Some(buf)) + ControlFlow::Continue(ret) } } impl Loader { - fn create_model(&mut self, vocabulary: Vocabulary) -> Model { + fn create_model(&mut self, vocabulary: Vocabulary) -> Result { (self.load_progress_callback)(LoadProgress::PartLoading { file: &self.path, current_part: 0, total_parts: 1, }); - let alloc = !(self.use_mmap && self.container_type == ContainerType::GGJT); + let alloc = !(self.use_mmap()); let Hyperparameters { n_vocab, n_embd, @@ -219,15 +241,27 @@ impl Loader { }; // Initialize the context let context = ggml::Context::init(ctx_size, alloc); - Model::new( - context, - self.hyperparameters, - vocabulary, - n_ff, - wtype, - self.container_type, - None, - ) + + let mmap = if self.container_type.support_mmap() { + let file = File::open(&self.path)?; + Some(unsafe { Mmap::map(&file)? }) + } else { + None + }; + + Ok(Model::new( + context, + self.hyperparameters, + vocabulary, + n_ff, + wtype, + self.container_type, + mmap, + )) + } + + fn use_mmap(&mut self) -> bool { + self.preper_mmap && self.container_type.support_mmap() } } diff --git a/llama-rs/src/model.rs b/llama-rs/src/model.rs index 0b547d2b..19e04144 100644 --- a/llama-rs/src/model.rs +++ b/llama-rs/src/model.rs @@ -25,7 +25,7 @@ pub struct Model { tensors: HashMap, /// Needs to kept alive while the model is alive - _mmap: Option, + pub(crate) mmap: Option, _version: ContainerType, @@ -39,7 +39,7 @@ impl Model { vocabulary: Vocabulary, n_ff: usize, wtype: ggml::Type, - model_type: ContainerType, + container_type: ContainerType, mmap: Option, ) -> Model { let n_embd = hparams.n_embd; @@ -110,8 +110,8 @@ impl Model { layers, tensors, _context: context, - _mmap: mmap, - _version: model_type, + mmap, + _version: container_type, } } @@ -120,16 +120,16 @@ impl Model { /// The status of the loading process will be reported through `load_progress_callback`. 
pub fn load( path: impl AsRef, - use_mmap: bool, + prefer_mmap: bool, n_context_tokens: usize, load_progress_callback: impl FnMut(LoadProgress), ) -> Result { let use_loader_2: bool = std::env::var("USE_LOADER_2").is_ok(); if use_loader_2 { - loader2::load(path, use_mmap, n_context_tokens, load_progress_callback) + loader2::load(path, prefer_mmap, n_context_tokens, load_progress_callback) } else { - loader::load(path, use_mmap, n_context_tokens, load_progress_callback) + loader::load(path, prefer_mmap, n_context_tokens, load_progress_callback) } } From e4834bde2192c63f0619ca0db933df44ee387c87 Mon Sep 17 00:00:00 2001 From: Locria Cyber <74560659+iacore@users.noreply.github.com> Date: Thu, 20 Apr 2023 18:07:25 +0000 Subject: [PATCH 31/42] cargo fmt --- ggml-loader/src/lib.rs | 2 +- llama-rs/src/loader2.rs | 24 ++++++++++++------------ 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/ggml-loader/src/lib.rs b/ggml-loader/src/lib.rs index 45b88058..c9f3d507 100644 --- a/ggml-loader/src/lib.rs +++ b/ggml-loader/src/lib.rs @@ -80,7 +80,7 @@ pub enum TensorDataTreatment<'a> { CopyInto(&'a mut [u8]), SeekPast { /// should be `tensor.nbytes` - n_bytes: usize + n_bytes: usize, }, } diff --git a/llama-rs/src/loader2.rs b/llama-rs/src/loader2.rs index e3ae6130..ba204c1e 100644 --- a/llama-rs/src/loader2.rs +++ b/llama-rs/src/loader2.rs @@ -31,9 +31,7 @@ impl LoadError { ggml_loader::LoadError::FailedCast(err) => LoadError::InvalidIntegerConversion(err), ggml_loader::LoadError::UserInterrupted(err) => err, ggml_loader::LoadError::UnsupportedElementType(ty) => { - LoadError::HyperparametersF16Invalid { - ftype: ty, - } + LoadError::HyperparametersF16Invalid { ftype: ty } } ggml_loader::LoadError::InvariantBroken(invariant) => { LoadError::InvariantBroken { path, invariant } @@ -168,7 +166,9 @@ impl ggml_loader::LoadHandler { let tensor = get_tensor!(); @@ -250,14 +250,14 @@ impl Loader { }; Ok(Model::new( - context, - self.hyperparameters, - vocabulary, - n_ff, - wtype, - self.container_type, - mmap, - )) + context, + self.hyperparameters, + vocabulary, + n_ff, + wtype, + self.container_type, + mmap, + )) } fn use_mmap(&mut self) -> bool { From c380ceeefaf6d7892d25587c54ebd14264f23462 Mon Sep 17 00:00:00 2001 From: Locria Cyber <74560659+iacore@users.noreply.github.com> Date: Thu, 20 Apr 2023 18:09:49 +0000 Subject: [PATCH 32/42] Make loader2 default --- llama-rs/src/model.rs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/llama-rs/src/model.rs b/llama-rs/src/model.rs index 19e04144..e9e3152e 100644 --- a/llama-rs/src/model.rs +++ b/llama-rs/src/model.rs @@ -124,7 +124,12 @@ impl Model { n_context_tokens: usize, load_progress_callback: impl FnMut(LoadProgress), ) -> Result { - let use_loader_2: bool = std::env::var("USE_LOADER_2").is_ok(); + let use_loader_2: bool = match std::env::var("GGML_LOADER").as_deref() { + Ok("2") => true, + Ok("1") => false, + Ok(_) => panic!("Please use GGML_LOADER=1 or GGML_LOADER=2"), + Err(_) => true, + }; if use_loader_2 { loader2::load(path, prefer_mmap, n_context_tokens, load_progress_callback) From 5b9788b5b47f27a908af37f6e26133ffcdba6979 Mon Sep 17 00:00:00 2001 From: Philpax Date: Sat, 22 Apr 2023 17:18:45 +0200 Subject: [PATCH 33/42] fix: remove dbg!(start_pos) --- ggml-loader/src/lib.rs | 3 --- 1 file changed, 3 deletions(-) diff --git a/ggml-loader/src/lib.rs b/ggml-loader/src/lib.rs index c9f3d507..8348f8fe 100644 --- a/ggml-loader/src/lib.rs +++ b/ggml-loader/src/lib.rs @@ -180,9 +180,6 @@ pub fn load_weights( ) -> 
Result<(), LoadError> {
     while has_data_left(reader)? {
         // load tensor header
-        let start_pos = reader.stream_position()?;
-        dbg!(start_pos);
-
         let n_dims: usize = read_i32(reader)?.try_into()?;
         let name_len = read_i32(reader)?;
         let ftype = decode_element_type_res(read_i32(reader)?)?;

From cbf0756329eef4e2474d93d7d6bf4f6372a2d590 Mon Sep 17 00:00:00 2001
From: Philpax
Date: Sat, 22 Apr 2023 17:18:58 +0200
Subject: [PATCH 34/42] fix: respect --no-mmap

---
 llama-rs/src/loader2.rs | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/llama-rs/src/loader2.rs b/llama-rs/src/loader2.rs
index ba204c1e..76fcf751 100644
--- a/llama-rs/src/loader2.rs
+++ b/llama-rs/src/loader2.rs
@@ -61,7 +61,7 @@ pub(crate) fn load(
         model: None,
         n_ctx: n_context_tokens,
         load_progress_callback,
-        preper_mmap: prefer_mmap,
+        prefer_mmap,
 
         tensor_accumulator: 0,
         hyperparameters: Hyperparameters::default(),
@@ -78,7 +78,7 @@ struct Loader {
     // input data and options
     path: PathBuf,
     n_ctx: usize,
-    preper_mmap: bool,
+    prefer_mmap: bool,
 
     // Internal state
     tensor_accumulator: usize,
@@ -242,7 +242,7 @@ impl Loader {
         // Initialize the context
        let context = ggml::Context::init(ctx_size, alloc);
 
-        let mmap = if self.container_type.support_mmap() {
+        let mmap = if self.use_mmap() {
            let file = File::open(&self.path)?;
            Some(unsafe { Mmap::map(&file)? })
        } else {
@@ -261,7 +261,7 @@ impl Loader {
     }
 
     fn use_mmap(&mut self) -> bool {
-        self.preper_mmap && self.container_type.support_mmap()
+        self.prefer_mmap && self.container_type.support_mmap()
     }
 }
 

From 430abfea2042dd6f51d4550c4d20280a145b1bbc Mon Sep 17 00:00:00 2001
From: Philpax
Date: Sat, 22 Apr 2023 17:37:25 +0200
Subject: [PATCH 35/42] chore: remove old comments

---
 llama-rs/src/loader2.rs | 11 -----------
 1 file changed, 11 deletions(-)

diff --git a/llama-rs/src/loader2.rs b/llama-rs/src/loader2.rs
index 76fcf751..c498cda5 100644
--- a/llama-rs/src/loader2.rs
+++ b/llama-rs/src/loader2.rs
@@ -1,12 +1,3 @@
-//! This is an experimental, *incomplete* implementation of a loader based on `ggml_loader`.
-//!
-//! At the time of writing, it does not successfully load any models.
-//!
-//! GGML/GGMF fails with an invariant broken error, and GGJT fails with an unexpected state error.
-//!
-//! It also does not support mmap, but it shouldn't be too hard to add: mmap as is done in `loader`, then populate
-//! the tensor from the [TensorInfo].
-
 use ggml_loader::util::*;
 use ggml_loader::*;
 use memmap2::Mmap;
@@ -160,7 +151,6 @@ impl ggml_loader::LoadHandler unsafe { let ptr = map.as_ptr().offset(info.start_offset as isize);
@@ -180,7 +170,6 @@ impl ggml_loader::LoadHandler

Date: Sat, 22 Apr 2023 17:39:27 +0200
Subject: [PATCH 36/42] chore: remove unused error case

---
 llama-rs/src/loader_common.rs | 21 ---------------------
 1 file changed, 21 deletions(-)

diff --git a/llama-rs/src/loader_common.rs b/llama-rs/src/loader_common.rs
index 1c264c88..07f60eb8 100644
--- a/llama-rs/src/loader_common.rs
+++ b/llama-rs/src/loader_common.rs
@@ -136,18 +136,6 @@ pub enum LoadError {
         /// The invariant that was broken.
         invariant: String,
     },
-    /// The loader was in an unexpected state.
-    ///
-    /// This error is not relevant unless `loader2` is being used.
-    #[error("unexpected state {state:?} in {path:?}: {context}")]
-    UnexpectedState {
-        /// The path that failed.
-        path: PathBuf,
-        /// The state that was encountered.
-        state: UnexpectedState,
-        /// Context about what was expected.
- context: String, - }, } impl From for LoadError { fn from(value: FindAllModelFilesError) -> Self { @@ -157,12 +145,3 @@ impl From for LoadError { } } } - -#[derive(Debug)] -/// The state that the loader was in when an error was encountered. -pub enum UnexpectedState { - /// The loader was in the `Vocabulary` state. - Vocabulary, - /// The loader was in the `Model` state. - Model, -} From 9b908ae57a77f0a134fd2720d66f024f06286ce4 Mon Sep 17 00:00:00 2001 From: Philpax Date: Sat, 22 Apr 2023 17:54:04 +0200 Subject: [PATCH 37/42] fix: remove some panics --- llama-rs/src/lib.rs | 2 +- llama-rs/src/loader2.rs | 19 +++++++++++-------- llama-rs/src/loader_common.rs | 15 ++++++++++++++- llama-rs/src/vocabulary.rs | 14 ++++++++++++-- 4 files changed, 38 insertions(+), 12 deletions(-) diff --git a/llama-rs/src/lib.rs b/llama-rs/src/lib.rs index 3f0a6c69..bc7cd919 100644 --- a/llama-rs/src/lib.rs +++ b/llama-rs/src/lib.rs @@ -22,7 +22,7 @@ pub use inference_session::{ pub use loader_common::{LoadError, LoadProgress}; pub use model::{Hyperparameters, Model}; pub use util::TokenUtf8Buffer; -pub use vocabulary::{TokenBias, TokenId, Vocabulary}; +pub use vocabulary::{AddTokenError, TokenBias, TokenId, Vocabulary}; /// The end of text token. pub const EOT_TOKEN_ID: TokenId = 2; // Hardcoded (for now?) diff --git a/llama-rs/src/loader2.rs b/llama-rs/src/loader2.rs index c498cda5..919585f3 100644 --- a/llama-rs/src/loader2.rs +++ b/llama-rs/src/loader2.rs @@ -62,7 +62,7 @@ pub(crate) fn load( ggml_loader::load_model_from_reader(&mut reader, &mut loader) .map_err(|err| LoadError::from_ggml_loader_error(err, path.clone()))?; - Ok(loader.model.expect("model should be initialized")) + loader.model.ok_or(LoadError::ModelNotCreated { path }) } struct Loader { @@ -110,9 +110,9 @@ impl ggml_loader::LoadHandler id, Err(err) => return ControlFlow::Break(LoadError::InvalidIntegerConversion(err)), }; - self.vocab - .push_token(id, token, score) - .expect("vocab should be valid"); + if let Err(err) = self.vocab.push_token(id, token, score) { + return ControlFlow::Break(LoadError::from(err)); + } ControlFlow::Continue(()) } @@ -124,10 +124,13 @@ impl ggml_loader::LoadHandler ControlFlow { - if self.model.is_none() { - self.model = Some(brkchk(self.create_model(self.vocab.clone()))?); - } - let model = &mut self.model.as_mut().expect("initialized"); + let model = match &mut self.model { + Some(model) => model, + None => { + let model = brkchk(self.create_model(self.vocab.clone()))?; + self.model.get_or_insert(model) + } + }; let tensor_name = match String::from_utf8(info.name) { Ok(n) => n, diff --git a/llama-rs/src/loader_common.rs b/llama-rs/src/loader_common.rs index 07f60eb8..662e1713 100644 --- a/llama-rs/src/loader_common.rs +++ b/llama-rs/src/loader_common.rs @@ -2,7 +2,7 @@ use std::path::{Path, PathBuf}; use thiserror::Error; -use crate::{util::FindAllModelFilesError, Hyperparameters}; +use crate::{util::FindAllModelFilesError, vocabulary::AddTokenError, Hyperparameters}; /// Each variant represents a step within the process of loading the model. /// These can be used to report progress to the user. @@ -78,6 +78,9 @@ pub enum LoadError { #[error("invalid integer conversion")] /// One of the integers encountered could not be converted to a more appropriate type. InvalidIntegerConversion(#[from] std::num::TryFromIntError), + /// While loading, a token could not be added to the vocabulary. 
+ #[error("failed to add token to vocabulary: {0}")] + VocabularyAddTokenFailed(#[from] AddTokenError), #[error("unsupported f16_: {0}")] /// One of the integers encountered could not be converted to a more appropriate type. UnsupportedElementType(i32), @@ -136,6 +139,16 @@ pub enum LoadError { /// The invariant that was broken. invariant: String, }, + /// The model could not be created. + /// + /// This implies that there were no tensors in the model to be loaded. + /// + /// This error is not relevant unless `loader2` is being used. + #[error("could not create model from {path:?}")] + ModelNotCreated { + /// The path that failed. + path: PathBuf, + }, } impl From for LoadError { fn from(value: FindAllModelFilesError) -> Self { diff --git a/llama-rs/src/vocabulary.rs b/llama-rs/src/vocabulary.rs index 974d24c0..21d6e5d2 100644 --- a/llama-rs/src/vocabulary.rs +++ b/llama-rs/src/vocabulary.rs @@ -27,14 +27,22 @@ pub struct Vocabulary { } #[derive(Debug, Clone, Error)] +/// Errors encountered when adding a token to a vocabulary. pub enum AddTokenError { #[error("the id of token added should be {expected_id}; is {actual_id}")] + /// The token that was added does not have the expected ID. WrongId { + /// The expected ID. expected_id: TokenId, + /// The actual ID. actual_id: TokenId, }, #[error("a token with the same id already exists, id={id}")] - AlreadyAdded { id: TokenId }, + /// A token with the same ID was already added. + AlreadyAdded { + /// The ID of the token that was already added. + id: TokenId, + }, } impl Vocabulary { @@ -57,7 +65,9 @@ impl Vocabulary { self.max_token_length = self.max_token_length.max(content.len()); self.id_to_token.push(content.clone()); self.id_to_token_score.push(score); - self.token_to_id.insert(content, id); + if self.token_to_id.insert(content, id).is_some() { + return Err(AddTokenError::AlreadyAdded { id }); + } Ok(()) } From d8c4ca699f349f53f6dc46954c584d57bf6bf8f3 Mon Sep 17 00:00:00 2001 From: Philpax Date: Sat, 22 Apr 2023 18:04:06 +0200 Subject: [PATCH 38/42] feat: remove AlreadyAdded error Apparently some models just have token dupes? /shrug --- llama-rs/src/loader.rs | 32 ++++++++++---------------------- llama-rs/src/vocabulary.rs | 10 +--------- 2 files changed, 11 insertions(+), 31 deletions(-) diff --git a/llama-rs/src/loader.rs b/llama-rs/src/loader.rs index ae056605..ac985fa3 100644 --- a/llama-rs/src/loader.rs +++ b/llama-rs/src/loader.rs @@ -84,37 +84,25 @@ pub(crate) fn load( // Load vocabulary // =============== let vocabulary = { - let mut id_to_token = vec![]; - let mut id_to_token_score = vec![]; - let mut token_to_id = HashMap::new(); - let mut max_token_length = 0; + let mut vocab = Vocabulary::default(); for i in 0..hparams.n_vocab { let len = read_i32(&mut reader)?; + let id = i as TokenId; let token = read_bytes_with_len(&mut reader, len.try_into()?)?; - max_token_length = max_token_length.max(token.len()); - id_to_token.push(token.clone()); - token_to_id.insert(token, TokenId::try_from(i)?); - - // Token score, currently unused - match model_type { - ContainerType::GGMF | ContainerType::GGJT => { - let score = read_f32(&mut reader)?; - id_to_token_score.push(score); - } + + let score = match model_type { + ContainerType::GGMF | ContainerType::GGJT => read_f32(&mut reader)?, ContainerType::GGML => { // Legacy model, set empty score - id_to_token_score.push(0.); + 0. 
} - } - } + }; - Vocabulary { - id_to_token, - id_to_token_score, - token_to_id, - max_token_length, + vocab.push_token(id, token, score)?; } + + vocab }; // for the big tensors, we have the option to store the data in 16-bit diff --git a/llama-rs/src/vocabulary.rs b/llama-rs/src/vocabulary.rs index 21d6e5d2..7bee66d2 100644 --- a/llama-rs/src/vocabulary.rs +++ b/llama-rs/src/vocabulary.rs @@ -37,12 +37,6 @@ pub enum AddTokenError { /// The actual ID. actual_id: TokenId, }, - #[error("a token with the same id already exists, id={id}")] - /// A token with the same ID was already added. - AlreadyAdded { - /// The ID of the token that was already added. - id: TokenId, - }, } impl Vocabulary { @@ -65,9 +59,7 @@ impl Vocabulary { self.max_token_length = self.max_token_length.max(content.len()); self.id_to_token.push(content.clone()); self.id_to_token_score.push(score); - if self.token_to_id.insert(content, id).is_some() { - return Err(AddTokenError::AlreadyAdded { id }); - } + self.token_to_id.insert(content, id); Ok(()) } From cabc4c93cc5b137309f6ac0220878504629102ce Mon Sep 17 00:00:00 2001 From: Locria Cyber <74560659+iacore@users.noreply.github.com> Date: Sat, 22 Apr 2023 18:26:21 +0000 Subject: [PATCH 39/42] minor fix --- llama-rs/src/loader2.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama-rs/src/loader2.rs b/llama-rs/src/loader2.rs index 919585f3..2c491f0b 100644 --- a/llama-rs/src/loader2.rs +++ b/llama-rs/src/loader2.rs @@ -128,7 +128,7 @@ impl ggml_loader::LoadHandler model, None => { let model = brkchk(self.create_model(self.vocab.clone()))?; - self.model.get_or_insert(model) + self.model.insert(model) } }; From 1930496c7fb9f0ab6cb4391d1579c108fa1e0773 Mon Sep 17 00:00:00 2001 From: Philpax Date: Sat, 22 Apr 2023 20:33:41 +0200 Subject: [PATCH 40/42] fix: Vocabulary::push_token is infallible --- llama-cli/src/cli_args.rs | 2 +- llama-rs/src/lib.rs | 2 +- llama-rs/src/loader.rs | 2 +- llama-rs/src/loader2.rs | 4 +--- llama-rs/src/loader_common.rs | 5 +---- llama-rs/src/vocabulary.rs | 32 ++++++-------------------------- 6 files changed, 11 insertions(+), 36 deletions(-) diff --git a/llama-cli/src/cli_args.rs b/llama-cli/src/cli_args.rs index 4c4e3fe4..e31d4f48 100644 --- a/llama-cli/src/cli_args.rs +++ b/llama-cli/src/cli_args.rs @@ -288,7 +288,7 @@ impl ModelLoad { } => { let current_part = current_part + 1; log::info!( - "Loading model part {}/{} from '{}' (mmap: {})\n", + "Loading model part {}/{} from '{}' (mmap preferred: {})\n", current_part, total_parts, file.to_string_lossy(), diff --git a/llama-rs/src/lib.rs b/llama-rs/src/lib.rs index bc7cd919..3f0a6c69 100644 --- a/llama-rs/src/lib.rs +++ b/llama-rs/src/lib.rs @@ -22,7 +22,7 @@ pub use inference_session::{ pub use loader_common::{LoadError, LoadProgress}; pub use model::{Hyperparameters, Model}; pub use util::TokenUtf8Buffer; -pub use vocabulary::{AddTokenError, TokenBias, TokenId, Vocabulary}; +pub use vocabulary::{TokenBias, TokenId, Vocabulary}; /// The end of text token. pub const EOT_TOKEN_ID: TokenId = 2; // Hardcoded (for now?) 
diff --git a/llama-rs/src/loader.rs b/llama-rs/src/loader.rs index ac985fa3..9ef545e5 100644 --- a/llama-rs/src/loader.rs +++ b/llama-rs/src/loader.rs @@ -99,7 +99,7 @@ pub(crate) fn load( } }; - vocab.push_token(id, token, score)?; + vocab.push_token(id, token, score); } vocab diff --git a/llama-rs/src/loader2.rs b/llama-rs/src/loader2.rs index 2c491f0b..eaa622cb 100644 --- a/llama-rs/src/loader2.rs +++ b/llama-rs/src/loader2.rs @@ -110,9 +110,7 @@ impl ggml_loader::LoadHandler id, Err(err) => return ControlFlow::Break(LoadError::InvalidIntegerConversion(err)), }; - if let Err(err) = self.vocab.push_token(id, token, score) { - return ControlFlow::Break(LoadError::from(err)); - } + self.vocab.push_token(id, token, score); ControlFlow::Continue(()) } diff --git a/llama-rs/src/loader_common.rs b/llama-rs/src/loader_common.rs index 662e1713..2cb1ae1f 100644 --- a/llama-rs/src/loader_common.rs +++ b/llama-rs/src/loader_common.rs @@ -2,7 +2,7 @@ use std::path::{Path, PathBuf}; use thiserror::Error; -use crate::{util::FindAllModelFilesError, vocabulary::AddTokenError, Hyperparameters}; +use crate::{util::FindAllModelFilesError, Hyperparameters}; /// Each variant represents a step within the process of loading the model. /// These can be used to report progress to the user. @@ -78,9 +78,6 @@ pub enum LoadError { #[error("invalid integer conversion")] /// One of the integers encountered could not be converted to a more appropriate type. InvalidIntegerConversion(#[from] std::num::TryFromIntError), - /// While loading, a token could not be added to the vocabulary. - #[error("failed to add token to vocabulary: {0}")] - VocabularyAddTokenFailed(#[from] AddTokenError), #[error("unsupported f16_: {0}")] /// One of the integers encountered could not be converted to a more appropriate type. UnsupportedElementType(i32), diff --git a/llama-rs/src/vocabulary.rs b/llama-rs/src/vocabulary.rs index 7bee66d2..20fa4a5c 100644 --- a/llama-rs/src/vocabulary.rs +++ b/llama-rs/src/vocabulary.rs @@ -1,7 +1,5 @@ use std::{collections::HashMap, str::FromStr}; -use thiserror::Error; - use crate::InferenceError; /// The identifier of a token in a vocabulary. @@ -26,41 +24,23 @@ pub struct Vocabulary { pub(crate) max_token_length: usize, } -#[derive(Debug, Clone, Error)] -/// Errors encountered when adding a token to a vocabulary. -pub enum AddTokenError { - #[error("the id of token added should be {expected_id}; is {actual_id}")] - /// The token that was added does not have the expected ID. - WrongId { - /// The expected ID. - expected_id: TokenId, - /// The actual ID. - actual_id: TokenId, - }, -} - impl Vocabulary { /// Add a token to the vocabulary. /// /// The token added must have `id` directly after the last token in the vocabulary. - pub fn push_token( - &mut self, - id: TokenId, - content: Token, - score: TokenScore, - ) -> Result<(), AddTokenError> { + pub fn push_token(&mut self, id: TokenId, content: Token, score: TokenScore) { + // These are loader invariants. If this is broken, then the loader is broken and this is a bug, + // not an issue with the model itself. 
assert_eq!(self.id_to_token.len(), self.id_to_token_score.len()); if self.id_to_token.len() != id as usize || self.id_to_token_score.len() != id as usize { - return Err(AddTokenError::WrongId { - expected_id: self.id_to_token.len() as TokenId, - actual_id: id, - }); + let expected_id = self.id_to_token.len() as TokenId; + panic!("the id of token added should be {expected_id}; is {id}"); } + self.max_token_length = self.max_token_length.max(content.len()); self.id_to_token.push(content.clone()); self.id_to_token_score.push(score); self.token_to_id.insert(content, id); - Ok(()) } pub(crate) fn token(&self, idx: usize) -> &[u8] { From bdb9856c4a99853e8477835d0ea8edf3b64d858a Mon Sep 17 00:00:00 2001 From: Philpax Date: Sat, 22 Apr 2023 20:56:43 +0200 Subject: [PATCH 41/42] fix: bail on multipart models with loader2 --- Cargo.lock | 1 - ggml-loader/src/lib.rs | 19 +++++-------------- ggml-loader/src/util.rs | 4 ++-- llama-rs/Cargo.toml | 1 - llama-rs/src/loader2.rs | 18 ++++++++++-------- llama-rs/src/loader_common.rs | 8 ++++++++ llama-rs/src/model.rs | 3 +++ 7 files changed, 28 insertions(+), 26 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 186da4eb..eff66fa1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -615,7 +615,6 @@ dependencies = [ "bytemuck", "ggml", "ggml-loader", - "log", "memmap2", "partial_sort", "protobuf", diff --git a/ggml-loader/src/lib.rs b/ggml-loader/src/lib.rs index 8348f8fe..7d29d4b3 100644 --- a/ggml-loader/src/lib.rs +++ b/ggml-loader/src/lib.rs @@ -96,12 +96,6 @@ pub trait LoadHandler { fn load_hyper_parameters(&mut self, reader: &mut R) -> ControlFlow; - /// multi-file loading is not supported - /// To handle that yourself, return [`ControlFlow::Break(_)`] here - fn load_multipart(&mut self, reader: &mut R) -> ControlFlow { - ControlFlow::Continue(()) - } - /// callback to get tensor buffer to populate /// /// # Returns @@ -128,7 +122,7 @@ pub fn load_model_from_reader( ggml::FILE_MAGIC_UNVERSIONED => ContainerType::GGML, magic => return Err(LoadError::InvalidMagic(magic)), }; - retchk(handler.got_container_type(container_type))?; + controlflow_to_result(handler.got_container_type(container_type))?; // Load format version match container_type { @@ -142,7 +136,7 @@ pub fn load_model_from_reader( } // Load hyper params - let hparams = retchk(handler.load_hyper_parameters(reader))?; + let hparams = controlflow_to_result(handler.load_hyper_parameters(reader))?; let n_vocab = hparams.n_vocab; // Load vocabulary @@ -156,15 +150,12 @@ pub fn load_model_from_reader( 0. } }; - retchk(handler.got_vocab_token(i, token, token_score))?; + controlflow_to_result(handler.got_vocab_token(i, token, token_score))?; } // Load tensor data match container_type { - ContainerType::GGMF | ContainerType::GGML => { - retchk(handler.load_multipart(reader))?; - load_weights(reader, handler, false) - } + ContainerType::GGMF | ContainerType::GGML => load_weights(reader, handler, false), ContainerType::GGJT => load_weights(reader, handler, true), } } @@ -227,7 +218,7 @@ pub fn load_weights( start_offset: offset_aligned, }; - match retchk(handler.tensor_buffer(tensor_info))? { + match controlflow_to_result(handler.tensor_buffer(tensor_info))? 
{ TensorDataTreatment::CopyInto(buf) => { if align { reader.seek(SeekFrom::Start(offset_aligned))?; diff --git a/ggml-loader/src/util.rs b/ggml-loader/src/util.rs index 92c2be40..33374fd6 100644 --- a/ggml-loader/src/util.rs +++ b/ggml-loader/src/util.rs @@ -62,14 +62,14 @@ pub fn decode_element_type_res(ftype: i32) -> Result(x: ControlFlow) -> Result> { +pub fn controlflow_to_result(x: ControlFlow) -> Result> { match x { ControlFlow::Continue(x) => Ok(x), ControlFlow::Break(y) => Err(LoadError::UserInterrupted(y)), } } -pub fn brkchk>(x: Result) -> ControlFlow { +pub fn result_to_controlflow>(x: Result) -> ControlFlow { match x { Ok(x) => ControlFlow::Continue(x), Err(y) => ControlFlow::Break(y.into()), diff --git a/llama-rs/Cargo.toml b/llama-rs/Cargo.toml index 6fecd50d..7ed254a4 100644 --- a/llama-rs/Cargo.toml +++ b/llama-rs/Cargo.toml @@ -23,7 +23,6 @@ memmap2 = "0.5.10" serde_json = { version = "1.0", optional = true } protobuf = { version = "= 2.14.0", optional = true } rust_tokenizers = { version = "3.1.2", optional = true } -log = "*" [features] convert = ["dep:serde_json", "dep:protobuf", "dep:rust_tokenizers"] diff --git a/llama-rs/src/loader2.rs b/llama-rs/src/loader2.rs index eaa622cb..ff84e3b1 100644 --- a/llama-rs/src/loader2.rs +++ b/llama-rs/src/loader2.rs @@ -9,7 +9,10 @@ use std::{ path::{Path, PathBuf}, }; -use crate::{util::mulf, Hyperparameters, LoadError, LoadProgress, Model, TokenId, Vocabulary}; +use crate::{ + util::{self, mulf}, + Hyperparameters, LoadError, LoadProgress, Model, TokenId, Vocabulary, +}; impl LoadError { fn from_ggml_loader_error(value: ggml_loader::LoadError, path: PathBuf) -> Self { @@ -39,6 +42,11 @@ pub(crate) fn load( ) -> Result { let main_path = path.as_ref(); + let paths = util::find_all_model_files(main_path)?; + if paths.len() != 1 { + return Err(LoadError::MultipartNotSupported { paths }); + } + let file = File::open(main_path).map_err(|e| LoadError::OpenFileFailed { source: e, path: main_path.to_owned(), @@ -115,17 +123,11 @@ impl ggml_loader::LoadHandler) -> ControlFlow { - // todo - log::warn!("multipart model is not supported"); - ControlFlow::Continue(()) - } - fn tensor_buffer(&mut self, info: TensorInfo) -> ControlFlow { let model = match &mut self.model { Some(model) => model, None => { - let model = brkchk(self.create_model(self.vocab.clone()))?; + let model = result_to_controlflow(self.create_model(self.vocab.clone()))?; self.model.insert(model) } }; diff --git a/llama-rs/src/loader_common.rs b/llama-rs/src/loader_common.rs index 2cb1ae1f..4a219642 100644 --- a/llama-rs/src/loader_common.rs +++ b/llama-rs/src/loader_common.rs @@ -146,6 +146,14 @@ pub enum LoadError { /// The path that failed. path: PathBuf, }, + /// Multiple parts of the model were found. + /// + /// Multi-part models are not supported. Please convert the model to a single part. + #[error("multipart models are not supported")] + MultipartNotSupported { + /// The paths that were found. + paths: Vec, + }, } impl From for LoadError { fn from(value: FindAllModelFilesError) -> Self { diff --git a/llama-rs/src/model.rs b/llama-rs/src/model.rs index e9e3152e..6cd64dc1 100644 --- a/llama-rs/src/model.rs +++ b/llama-rs/src/model.rs @@ -124,6 +124,9 @@ impl Model { n_context_tokens: usize, load_progress_callback: impl FnMut(LoadProgress), ) -> Result { + // Loader2 is the default. It can support GGML, GGMF and GGJT, but does not support multipart models. + // + // Loader1 is the old loader. It can support multipart models, but will be deprecated. 
let use_loader_2: bool = match std::env::var("GGML_LOADER").as_deref() { Ok("2") => true, Ok("1") => false, From b41fe14d2c364c4bb85aefae2661eb0a02c57484 Mon Sep 17 00:00:00 2001 From: Philpax Date: Sat, 22 Apr 2023 22:33:53 +0200 Subject: [PATCH 42/42] refactor: make Vocabulary::push_token pub(crate) --- llama-rs/src/vocabulary.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/llama-rs/src/vocabulary.rs b/llama-rs/src/vocabulary.rs index 20fa4a5c..32bdd07f 100644 --- a/llama-rs/src/vocabulary.rs +++ b/llama-rs/src/vocabulary.rs @@ -28,7 +28,11 @@ impl Vocabulary { /// Add a token to the vocabulary. /// /// The token added must have `id` directly after the last token in the vocabulary. - pub fn push_token(&mut self, id: TokenId, content: Token, score: TokenScore) { + /// + /// # Panics + /// - This function can panic if `id` does not correspond to the next token in the vocabulary. + /// That is, if there are already `n` tokens in the vocabulary, then `id` must be `n`. + pub(crate) fn push_token(&mut self, id: TokenId, content: Token, score: TokenScore) { // These are loader invariants. If this is broken, then the loader is broken and this is a bug, // not an issue with the model itself. assert_eq!(self.id_to_token.len(), self.id_to_token_score.len());
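
For reference, the net effect of the last two patches on Vocabulary::push_token can be sketched in isolation as below. This is a hypothetical, self-contained rendering, not the crate's actual code: the type aliases and the main driver are illustrative only. It shows the behaviour the patches establish — ids must arrive in sequence, the function panics instead of returning an error, and duplicate token bytes are tolerated (the reverse mapping is simply overwritten).

// Sketch only; mirrors the invariant from PATCH 40/42 and the dupe tolerance from PATCH 38/42.
use std::collections::HashMap;

type TokenId = i32;
type Token = Vec<u8>;
type TokenScore = f32;

#[derive(Default)]
struct Vocabulary {
    id_to_token: Vec<Token>,
    id_to_token_score: Vec<TokenScore>,
    token_to_id: HashMap<Token, TokenId>,
    max_token_length: usize,
}

impl Vocabulary {
    // Panics if `id` is not exactly the next id; this is a loader invariant,
    // so a violation indicates a loader bug, not a bad model file.
    fn push_token(&mut self, id: TokenId, content: Token, score: TokenScore) {
        assert_eq!(self.id_to_token.len(), self.id_to_token_score.len());
        if self.id_to_token.len() != id as usize {
            let expected_id = self.id_to_token.len() as TokenId;
            panic!("the id of token added should be {expected_id}; is {id}");
        }
        self.max_token_length = self.max_token_length.max(content.len());
        self.id_to_token.push(content.clone());
        self.id_to_token_score.push(score);
        // Duplicate token bytes are allowed; the later id wins in the reverse map.
        self.token_to_id.insert(content, id);
    }
}

fn main() {
    let mut vocab = Vocabulary::default();
    // A loader calls this once per vocabulary entry, in id order.
    vocab.push_token(0, b"<unk>".to_vec(), 0.0);
    vocab.push_token(1, b"hello".to_vec(), -1.5);
    // vocab.push_token(5, b"skip".to_vec(), 0.0); // would panic: ids must be sequential
    assert_eq!(vocab.id_to_token.len(), 2);
}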