Skip to content
This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

Commit

Permalink
Move loader to standalone crate llama-loader
Browse files Browse the repository at this point in the history
  • Loading branch information
iacore committed Apr 8, 2023
1 parent 0c45c71 commit 9f625c0
Show file tree
Hide file tree
Showing 10 changed files with 127 additions and 88 deletions.
9 changes: 9 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
members = [
"ggml-sys",
"ggml",
"llama-loader",
"llama-rs",
"llama-cli",
"generate-ggml-bindings"
Expand Down
2 changes: 1 addition & 1 deletion ggml/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ impl Tensor {
// /// Set the tensor's data pointer (useful for mmap-ed data)
// ///
// /// # Safety
// ///
// ///
// /// The memory region from `data_ptr` to `data_ptr.offset(tensor.nbytes())` will be read from.
// pub unsafe fn set_data(&self, data_ptr: *mut c_void) {
// self.with_alive_ctx(|| {
Expand Down
10 changes: 10 additions & 0 deletions llama-loader/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
[package]
name = "llama-loader"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
ggml = { path = "../ggml" }
thiserror = "*"
100 changes: 46 additions & 54 deletions llama-rs/src/loader2.rs → llama-loader/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,31 @@
#![allow(missing_docs)]
//! standalone model loader
//!
//! Only the hyperparameters are llama-specific. Everything else can be reused for other LLMs.
#![allow(clippy::nonminimal_bool)]

//! standalone model loader
pub mod util;

use std::{
io::{BufRead, Seek, SeekFrom},
ops::ControlFlow,
};
use util::*;

pub type ElementType = ggml::Type;

/// file type containing the model
///
/// Identifies which on-disk container layout a model file uses; detected
/// from the magic number at the start of the file.
// Fieldless enum: derive the full equality/hash set so it can be used as a
// map key and compared structurally (clippy: derive_partial_eq_without_eq).
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
#[allow(clippy::upper_case_acronyms)]
pub enum ContainerType {
    /// legacy format, oldest ggml tensor file format
    GGML,
    /// also legacy format, newer than GGML, older than GGJT
    GGMF,
    /// mmap-able format
    GGJT,
}

use crate::{loader::has_data_left, ElementType, ModelContainerType};

pub(crate) fn decode_element_type(ftype: i32) -> Option<ElementType> {
pub fn decode_element_type(ftype: i32) -> Option<ElementType> {
match ftype {
0 => Some(ggml::Type::F32),
1 => Some(ggml::Type::F16),
Expand All @@ -19,7 +35,7 @@ pub(crate) fn decode_element_type(ftype: i32) -> Option<ElementType> {
}
}

pub(crate) fn encode_element_type(element_type: ElementType) -> Option<i32> {
pub fn encode_element_type(element_type: ElementType) -> Option<i32> {
match element_type {
ggml::Type::F32 => Some(0),
ggml::Type::F16 => Some(1),
Expand All @@ -29,38 +45,9 @@ pub(crate) fn encode_element_type(element_type: ElementType) -> Option<i32> {
}
}

pub(crate) fn read_bytes<const N: usize>(
reader: &mut impl BufRead,
) -> Result<[u8; N], std::io::Error> {
let mut bytes = [0u8; N];
reader.read_exact(&mut bytes)?;
Ok(bytes)
}

pub(crate) fn read_i32(reader: &mut impl BufRead) -> Result<i32, std::io::Error> {
Ok(i32::from_le_bytes(read_bytes::<4>(reader)?))
}

pub(crate) fn read_u32(reader: &mut impl BufRead) -> Result<u32, std::io::Error> {
Ok(u32::from_le_bytes(read_bytes::<4>(reader)?))
}

pub(crate) fn read_f32(reader: &mut impl BufRead) -> Result<f32, std::io::Error> {
Ok(f32::from_le_bytes(read_bytes::<4>(reader)?))
}

pub(crate) fn read_bytes_with_len(
reader: &mut impl BufRead,
len: usize,
) -> Result<Vec<u8>, std::io::Error> {
let mut bytes = vec![0u8; len];
reader.read_exact(&mut bytes)?;
Ok(bytes)
}

/// The hyperparameters of the model.
#[derive(Debug, Clone)]
pub struct FixedHyperparameters {
pub struct LlamaHyperparameters {
pub n_vocab: usize,
pub n_embd: usize,
pub n_mult: usize,
Expand Down Expand Up @@ -90,7 +77,7 @@ pub enum LoadError<T> {

#[error("unsupported tensor dtype/f16_: {0}")]
UnsupportedElementtype(i32),

/// sanity check failed
#[error("invariant broken: {0}")]
InvariantBroken(String),
Expand All @@ -107,11 +94,11 @@ pub struct TensorInfo {

#[allow(unused_variables)]
pub trait LoadHandler<T> {
fn cb_container_type(&mut self, model_type: ModelContainerType) -> ControlFlow<T> {
fn cb_container_type(&mut self, model_type: ContainerType) -> ControlFlow<T> {
ControlFlow::Continue(())
}

fn cb_hyper_parameters(&mut self, hparams: FixedHyperparameters) -> ControlFlow<T> {
fn cb_hyper_parameters(&mut self, hparams: LlamaHyperparameters) -> ControlFlow<T> {
ControlFlow::Continue(())
}

Expand All @@ -134,30 +121,30 @@ pub fn load_model_from_reader<T>(
handler: &mut impl LoadHandler<T>,
) -> Result<(), LoadError<T>> {
// Verify magic
let container_type: ModelContainerType = match read_u32(&mut reader)? {
ggml::FILE_MAGIC_GGMF => ModelContainerType::GGMF,
ggml::FILE_MAGIC_GGJT => ModelContainerType::GGJT,
ggml::FILE_MAGIC_UNVERSIONED => ModelContainerType::Unversioned,
let container_type: ContainerType = match read_u32(&mut reader)? {
ggml::FILE_MAGIC_GGMF => ContainerType::GGMF,
ggml::FILE_MAGIC_GGJT => ContainerType::GGJT,
ggml::FILE_MAGIC_UNVERSIONED => ContainerType::GGML,
magic => return Err(LoadError::InvalidMagic(magic)),
};
retchk(handler.cb_container_type(container_type))?;

// Load format version
match container_type {
ModelContainerType::GGMF | ModelContainerType::GGJT => {
ContainerType::GGMF | ContainerType::GGJT => {
let _version: u32 = match read_u32(&mut reader)? {
ggml::FORMAT_VERSION => ggml::FORMAT_VERSION,
version => return Err(LoadError::InvalidFormatVersion(version)),
};
}
ModelContainerType::Unversioned => {}
ContainerType::GGML => {}
}

// Load hyper params
//
// NOTE: Field order matters! Data is laid out in the file exactly
// in this order.
let hparams = FixedHyperparameters {
let hparams = LlamaHyperparameters {
n_vocab: read_i32(&mut reader)?.try_into()?,
n_embd: read_i32(&mut reader)?.try_into()?,
n_mult: read_i32(&mut reader)?.try_into()?,
Expand All @@ -174,8 +161,8 @@ pub fn load_model_from_reader<T>(
let len = read_u32(&mut reader)?.try_into()?;
let token = read_bytes_with_len(&mut reader, len)?;
let token_score = match container_type {
ModelContainerType::GGMF | ModelContainerType::GGJT => read_f32(&mut reader)?,
ModelContainerType::Unversioned => {
ContainerType::GGMF | ContainerType::GGJT => read_f32(&mut reader)?,
ContainerType::GGML => {
// Legacy model, set empty score
0.
}
Expand All @@ -185,12 +172,12 @@ pub fn load_model_from_reader<T>(

// Load tensor data
match container_type {
ModelContainerType::GGMF | ModelContainerType::Unversioned => {
ContainerType::GGMF | ContainerType::GGML => {
let _file_offset = reader.stream_position()?;
drop(reader);
todo!()
}
ModelContainerType::GGJT => load_weights_ggjt(&mut reader, handler),
ContainerType::GGJT => load_weights_ggjt(&mut reader, handler),
}
}

Expand Down Expand Up @@ -238,23 +225,28 @@ fn load_weights_ggjt<T>(
}

let tensor_info = TensorInfo {
name, dims, n_dims, n_elements, ftype,
name,
dims,
n_dims,
n_elements,
ftype,
};

// load tensor weights
let offset_curr = reader.stream_position()?;
let offset_aligned: u64 = (offset_curr + 31) & !31;
reader.seek(SeekFrom::Start(offset_aligned))?;

let type_size = ggml::type_size(ftype);
let buf = retchk(handler.tensor_buffer(tensor_info))?;
let buf_len = buf.len();
if !(buf_len == type_size * n_elements) {
return Err(LoadError::InvariantBroken(format!("{buf_len} == {type_size} * {n_elements}")));
return Err(LoadError::InvariantBroken(format!(
"{buf_len} == {type_size} * {n_elements}"
)));
}
reader.read_exact(buf)?;
}

Ok(())
}

33 changes: 33 additions & 0 deletions llama-loader/src/util.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
use std::io::BufRead;

/// Read exactly `N` bytes from `reader` into a fixed-size array.
///
/// # Errors
/// Propagates any I/O error, including `UnexpectedEof` when fewer than
/// `N` bytes remain.
pub fn read_bytes<const N: usize>(reader: &mut impl BufRead) -> Result<[u8; N], std::io::Error> {
    let mut buf = [0u8; N];
    reader.read_exact(&mut buf).map(|()| buf)
}

/// Read a little-endian `i32` from `reader`.
///
/// # Errors
/// Returns an I/O error if 4 bytes cannot be read.
pub fn read_i32(reader: &mut impl BufRead) -> Result<i32, std::io::Error> {
    let mut word = [0u8; 4];
    reader.read_exact(&mut word)?;
    Ok(i32::from_le_bytes(word))
}

/// Read a little-endian `u32` from `reader`.
///
/// # Errors
/// Returns an I/O error if 4 bytes cannot be read.
pub fn read_u32(reader: &mut impl BufRead) -> Result<u32, std::io::Error> {
    let mut word = [0u8; 4];
    reader.read_exact(&mut word)?;
    Ok(u32::from_le_bytes(word))
}

/// Read a little-endian IEEE-754 `f32` from `reader`.
///
/// # Errors
/// Returns an I/O error if 4 bytes cannot be read.
pub fn read_f32(reader: &mut impl BufRead) -> Result<f32, std::io::Error> {
    let mut word = [0u8; 4];
    reader.read_exact(&mut word)?;
    Ok(f32::from_le_bytes(word))
}

/// Read exactly `len` bytes from `reader` into a freshly allocated `Vec`.
///
/// # Errors
/// Propagates any I/O error, including `UnexpectedEof` when fewer than
/// `len` bytes remain.
pub fn read_bytes_with_len(
    reader: &mut impl BufRead,
    len: usize,
) -> Result<Vec<u8>, std::io::Error> {
    let mut data = Vec::new();
    data.resize(len, 0u8);
    reader.read_exact(&mut data).map(|()| data)
}

// NOTE: Implementation from #![feature(buf_read_has_data_left)]
/// Return `true` if at least one more byte can be read from `reader`.
///
/// # Errors
/// Propagates any I/O error raised while refilling the internal buffer.
pub fn has_data_left(reader: &mut impl BufRead) -> Result<bool, std::io::Error> {
    let buffered = reader.fill_buf()?;
    Ok(!buffered.is_empty())
}
1 change: 1 addition & 0 deletions llama-rs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ rust-version = "1.65"

[dependencies]
ggml = { path = "../ggml" }
llama-loader = { path = "../llama-loader" }

bytemuck = "1.13.1"
partial_sort = "0.2.0"
Expand Down
3 changes: 2 additions & 1 deletion llama-rs/src/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ use std::{
vec,
};

use crate::{util, Hyperparameters, Vocabulary, loader2::encode_element_type};
use crate::{util, Hyperparameters, Vocabulary};
use llama_loader::encode_element_type;

/// Converts a `pth` file to a `ggml` file.
pub fn convert_pth_to_ggml(model_directory: &Path, element_type: ggml::Type) {
Expand Down
Loading

0 comments on commit 9f625c0

Please sign in to comment.