Skip to content
This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

Standalone loader #125

Merged
merged 46 commits into from
Apr 22, 2023
Merged
Show file tree
Hide file tree
Changes from 45 commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
bdbea68
Add loader stub for GGJT
iacore Apr 6, 2023
b0a666f
Add loading code for ggjt
iacore Apr 6, 2023
9eefdc5
code cleanup that doesn't change anything
iacore Apr 6, 2023
c212c53
more code cleanup
iacore Apr 6, 2023
bfaec3a
minor change
iacore Apr 7, 2023
b6044ee
Add non-mmap loader for GGJT
iacore Apr 7, 2023
1872dda
Prefer traits in loader.rs
iacore Apr 7, 2023
ec1fca7
cargo fmt
iacore Apr 7, 2023
cc846ae
cargo clippy --fix
iacore Apr 7, 2023
bf847dd
Remove ggml::Tensor::set_data
iacore Apr 7, 2023
ea7094c
fix(llama): buffer tokens until valid UTF-8
philpax Apr 7, 2023
c848d5e
Add standalone loader
iacore Apr 8, 2023
8390593
Move loader to standalone crate llama-loader
iacore Apr 8, 2023
15fe19b
[llama-loader] Support non-copy loader
iacore Apr 8, 2023
2e9311d
Use functions from the new crate
iacore Apr 8, 2023
4dd0fc5
Merge branch 'main' into llama-loader
philpax Apr 13, 2023
c40e36e
Merge branch 'main' of github.com:rustformers/llama-rs into llama-loader
philpax Apr 13, 2023
34429e0
refactor(llama): pass mut tensors down
philpax Apr 13, 2023
38e7d58
feat/loader Make hparams configurable
iacore Apr 14, 2023
5dfc55d
feat/loader Add hook to support multi-part model loading
iacore Apr 14, 2023
48efd74
rename llama-loader to ggml-loader
iacore Apr 14, 2023
0fbbedd
Merge branch 'main' into llama-loader
philpax Apr 19, 2023
d65996d
fix
jon-chuang Apr 12, 2023
267d8ae
no_alloc
jon-chuang Apr 12, 2023
81a6979
chore: fix clippy
philpax Apr 19, 2023
80d189e
refactor(util): make find_all_model_files error
philpax Apr 19, 2023
85e1148
UnsupportedElementtype -> UnsupportedElementType
philpax Apr 19, 2023
3f29992
feat: experimental loader2 wire-up (incomplete)
philpax Apr 19, 2023
94951c4
fix dead doc link
philpax Apr 19, 2023
69f355b
feat: turn mmap on by default, add --no-mmap
philpax Apr 19, 2023
17bc0cc
Fix loading GGJT
iacore Apr 20, 2023
6641ae9
minor fix
iacore Apr 20, 2023
3910b6a
Add mmap
iacore Apr 20, 2023
e4834bd
cargo fmt
iacore Apr 20, 2023
c380cee
Make loader2 default
iacore Apr 20, 2023
5b9788b
fix: remove dbg!(start_pos)
philpax Apr 22, 2023
cbf0756
fix: respect --no-mmap
philpax Apr 22, 2023
8813b0f
Merge branch 'main' of github.com:rustformers/llama-rs into llama-loader
philpax Apr 22, 2023
430abfe
chore: remove old comments
philpax Apr 22, 2023
bf6a917
chore: remove unused error case
philpax Apr 22, 2023
9b908ae
fix: remove some panics
philpax Apr 22, 2023
d8c4ca6
feat: remove AlreadyAdded error
philpax Apr 22, 2023
cabc4c9
minor fix
iacore Apr 22, 2023
1930496
fix: Vocabulary::push_token is infallible
philpax Apr 22, 2023
bdb9856
fix: bail on multipart models with loader2
philpax Apr 22, 2023
b41fe14
refactor: make Vocabulary::push_token pub(crate)
philpax Apr 22, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
412 changes: 309 additions & 103 deletions Cargo.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
members = [
"ggml-sys",
"ggml",
"ggml-loader",
"llama-rs",
"llama-cli",
"generate-ggml-bindings"
Expand Down
10 changes: 10 additions & 0 deletions ggml-loader/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
[package]
name = "ggml-loader"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
ggml = { path = "../ggml" }
thiserror = "*"
236 changes: 236 additions & 0 deletions ggml-loader/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
//! standalone model loader
//!
//! Only the hyperparameter is llama-specific. Everything else can be reused for other LLM.
#![allow(clippy::nonminimal_bool)]

pub mod util;

use std::ops::ControlFlow;
use util::*;

pub type ElementType = ggml::Type;

/// file type containing the model
#[derive(Debug, PartialEq, Clone, Copy)]
#[allow(clippy::upper_case_acronyms)]
pub enum ContainerType {
/// legacy format, oldest ggml tensor file format
GGML,
/// also legacy format, newer than GGML, older than GGJT
GGMF,
/// mmap-able format
GGJT,
}

impl ContainerType {
pub fn support_mmap(&self) -> bool {
match self {
ContainerType::GGML => false,
ContainerType::GGMF => false,
ContainerType::GGJT => true,
}
}
}

#[derive(Debug, thiserror::Error)]
pub enum LoadError<T> {
#[error("invalid file magic number: {0}")]
InvalidMagic(u32),

#[error("invalid ggml format: version={0}")]
InvalidFormatVersion(u32),

#[error("{0}")]
Io(#[from] std::io::Error),

#[error("{0}")]
FailedCast(#[from] std::num::TryFromIntError),

/// return `ControlFlow::Break` from any of the `cb_*` function to trigger this error
#[error("user requested interrupt: {0}")]
UserInterrupted(T),

#[error("unsupported tensor dtype/f16_: {0}")]
UnsupportedElementType(i32),

/// sanity check failed
#[error("invariant broken: {0}")]
InvariantBroken(String),
}

#[derive(Debug, Clone)]
pub struct TensorInfo {
pub name: Vec<u8>,
pub n_dims: usize,
pub dims: [usize; 2],
pub n_elements: usize,
pub ftype: ElementType,
/// start of tensor - start of file
pub start_offset: u64,
}

/// Info in hyperparameter used for later loading tasks. Used in callback.
/// see [`LoadHandler::load_hyper_parameters`]
#[derive(Debug, Clone)]
pub struct PartialHyperparameters {
pub n_vocab: usize,
}

pub enum TensorDataTreatment<'a> {
CopyInto(&'a mut [u8]),
SeekPast {
/// should be `tensor.nbytes`
n_bytes: usize,
},
}

#[allow(unused_variables)]
pub trait LoadHandler<T, R: BufRead + Seek> {
fn got_container_type(&mut self, container_type: ContainerType) -> ControlFlow<T> {
ControlFlow::Continue(())
}

fn got_vocab_token(&mut self, i: usize, token: Vec<u8>, score: f32) -> ControlFlow<T> {
ControlFlow::Continue(())
}

fn load_hyper_parameters(&mut self, reader: &mut R) -> ControlFlow<T, PartialHyperparameters>;

/// callback to get tensor buffer to populate
///
/// # Returns
///
/// `None` to skip copying
/// `Some(buf)` to provide a buffer for copying weights into
fn tensor_buffer(&mut self, info: TensorInfo) -> ControlFlow<T, TensorDataTreatment>;
}

#[test]
fn can_be_vtable() {
use std::mem::MaybeUninit;
let _a: MaybeUninit<Box<dyn LoadHandler<(), std::fs::File>>> = MaybeUninit::uninit();
}

pub fn load_model_from_reader<T, R: BufRead + Seek>(
reader: &mut R,
handler: &mut impl LoadHandler<T, R>,
) -> Result<(), LoadError<T>> {
// Verify magic
let container_type: ContainerType = match read_u32(reader)? {
ggml::FILE_MAGIC_GGMF => ContainerType::GGMF,
ggml::FILE_MAGIC_GGJT => ContainerType::GGJT,
ggml::FILE_MAGIC_UNVERSIONED => ContainerType::GGML,
magic => return Err(LoadError::InvalidMagic(magic)),
};
controlflow_to_result(handler.got_container_type(container_type))?;

// Load format version
match container_type {
ContainerType::GGMF | ContainerType::GGJT => {
let _version: u32 = match read_u32(reader)? {
ggml::FORMAT_VERSION => ggml::FORMAT_VERSION,
version => return Err(LoadError::InvalidFormatVersion(version)),
};
}
ContainerType::GGML => {}
}

// Load hyper params
let hparams = controlflow_to_result(handler.load_hyper_parameters(reader))?;
let n_vocab = hparams.n_vocab;

// Load vocabulary
for i in 0..n_vocab {
let len = read_u32(reader)?.try_into()?;
let token = read_bytes_with_len(reader, len)?;
let token_score = match container_type {
ContainerType::GGMF | ContainerType::GGJT => read_f32(reader)?,
ContainerType::GGML => {
// Legacy model, set empty score
0.
}
};
controlflow_to_result(handler.got_vocab_token(i, token, token_score))?;
}

// Load tensor data
match container_type {
ContainerType::GGMF | ContainerType::GGML => load_weights(reader, handler, false),
ContainerType::GGJT => load_weights(reader, handler, true),
}
}

/// # Params
///
/// `align`
/// align to 4 bytes before reading tensor weights
pub fn load_weights<T, R: BufRead + Seek>(
reader: &mut R,
handler: &mut impl LoadHandler<T, R>,
align: bool,
) -> Result<(), LoadError<T>> {
while has_data_left(reader)? {
// load tensor header
let n_dims: usize = read_i32(reader)?.try_into()?;
let name_len = read_i32(reader)?;
let ftype = decode_element_type_res(read_i32(reader)?)?;

let mut n_elements: usize = 1;
let mut dims = [1usize, 1];
let ne_len = dims.len();
if !(n_dims <= ne_len) {
return Err(LoadError::InvariantBroken(format!("{n_dims} <= {ne_len}")));
}
#[allow(clippy::needless_range_loop)]
for i in 0..n_dims {
let dim: usize = read_i32(reader)?.try_into()?;
dims[i] = dim;
n_elements *= dim;
}

// load tensor name
let name = read_bytes_with_len(reader, name_len.try_into()?)?;

// sanity check
match ftype {
ElementType::Q4_0 | ElementType::Q4_1 => {
if !(dims[0] % 64 == 0) {
return Err(LoadError::InvariantBroken(format!("{dims:?}[0] % 64 == 0")));
}
}
_ => {}
}

// load tensor weights
let offset_curr = reader.stream_position()?;
let offset_aligned: u64 = if align {
(offset_curr + 31) & !31
} else {
offset_curr
};

let tensor_info = TensorInfo {
name,
dims,
n_dims,
n_elements,
ftype,
start_offset: offset_aligned,
};

match controlflow_to_result(handler.tensor_buffer(tensor_info))? {
TensorDataTreatment::CopyInto(buf) => {
if align {
reader.seek(SeekFrom::Start(offset_aligned))?;
}
reader.read_exact(buf)?;
}
TensorDataTreatment::SeekPast { n_bytes } => {
// skip if no buffer is given
reader.seek(SeekFrom::Start(offset_aligned + n_bytes as u64))?;
}
}
}

Ok(())
}
77 changes: 77 additions & 0 deletions ggml-loader/src/util.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
pub use std::io::{BufRead, Seek, SeekFrom};
use std::ops::ControlFlow;

use crate::{ElementType, LoadError};

pub fn read_bytes<const N: usize>(reader: &mut impl BufRead) -> Result<[u8; N], std::io::Error> {
let mut bytes = [0u8; N];
reader.read_exact(&mut bytes)?;
Ok(bytes)
}

pub fn read_i32(reader: &mut impl BufRead) -> Result<i32, std::io::Error> {
Ok(i32::from_le_bytes(read_bytes::<4>(reader)?))
}

pub fn read_u32(reader: &mut impl BufRead) -> Result<u32, std::io::Error> {
Ok(u32::from_le_bytes(read_bytes::<4>(reader)?))
}

pub fn read_f32(reader: &mut impl BufRead) -> Result<f32, std::io::Error> {
Ok(f32::from_le_bytes(read_bytes::<4>(reader)?))
}

pub fn read_bytes_with_len(
reader: &mut impl BufRead,
len: usize,
) -> Result<Vec<u8>, std::io::Error> {
let mut bytes = vec![0u8; len];
reader.read_exact(&mut bytes)?;
Ok(bytes)
}

// NOTE: Implementation from #![feature(buf_read_has_data_left)]
pub fn has_data_left(reader: &mut impl BufRead) -> Result<bool, std::io::Error> {
reader.fill_buf().map(|b| !b.is_empty())
}

pub fn decode_element_type(ftype: i32) -> Option<ElementType> {
match ftype {
0 => Some(ggml::Type::F32),
1 => Some(ggml::Type::F16),
2 => Some(ggml::Type::Q4_0),
3 => Some(ggml::Type::Q4_1),
_ => None,
}
}

pub fn encode_element_type(element_type: ElementType) -> Option<i32> {
match element_type {
ggml::Type::F32 => Some(0),
ggml::Type::F16 => Some(1),
ggml::Type::Q4_0 => Some(2),
ggml::Type::Q4_1 => Some(3),
_ => None,
}
}

pub fn decode_element_type_res<T>(ftype: i32) -> Result<ElementType, LoadError<T>> {
match decode_element_type(ftype) {
Some(x) => Ok(x),
None => Err(LoadError::UnsupportedElementType(ftype)),
}
}

pub fn controlflow_to_result<A, B>(x: ControlFlow<A, B>) -> Result<B, LoadError<A>> {
match x {
ControlFlow::Continue(x) => Ok(x),
ControlFlow::Break(y) => Err(LoadError::UserInterrupted(y)),
}
}

pub fn result_to_controlflow<A, B, C: Into<A>>(x: Result<B, C>) -> ControlFlow<A, B> {
match x {
Ok(x) => ControlFlow::Continue(x),
Err(y) => ControlFlow::Break(y.into()),
}
}
Loading