Merge pull request #152 from philpax/load-tensors-as-stored
fix #149 - load tensors by type, ignoring filetype
philpax authored Apr 25, 2023
2 parents 1b20306 + c9e5c26 commit 8254deb
Showing 13 changed files with 402 additions and 241 deletions.
2 changes: 1 addition & 1 deletion ggml-loader/Cargo.toml
@@ -7,4 +7,4 @@ edition = "2021"

[dependencies]
ggml = { path = "../ggml" }
thiserror = "*"
thiserror = "1.0"
28 changes: 18 additions & 10 deletions ggml-loader/src/lib.rs
@@ -10,7 +10,7 @@ use util::*;

pub type ElementType = ggml::Type;

/// file type containing the model
/// the format of the file containing the model
#[derive(Debug, PartialEq, Clone, Copy)]
#[allow(clippy::upper_case_acronyms)]
pub enum ContainerType {
@@ -21,7 +21,6 @@ pub enum ContainerType {
/// mmap-able format
GGJT,
}

impl ContainerType {
pub fn support_mmap(&self) -> bool {
match self {
@@ -64,10 +63,19 @@ pub struct TensorInfo {
pub n_dims: usize,
pub dims: [usize; 2],
pub n_elements: usize,
pub ftype: ElementType,
pub element_type: ElementType,
/// start of tensor - start of file
pub start_offset: u64,
}
impl TensorInfo {
pub fn calc_size(&self) -> usize {
let mut size = ggml::type_size(self.element_type);
for &dim in &self.dims[0..self.n_dims] {
size *= dim;
}
size / ggml::blck_size(self.element_type)
}
}
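
Note: to make the new calc_size concrete, here is a standalone re-derivation for a 4096×4096 Q4_0 tensor. This is a sketch with the period-appropriate ggml constants assumed: a Q4_0 block packs 32 elements into 20 bytes (one f32 scale plus 16 bytes of 4-bit quants).

    fn main() {
        // Assumed values of ggml::type_size(Q4_0) and ggml::blck_size(Q4_0).
        let (type_size, blck_size) = (20usize, 32usize);
        let (dims, n_dims) = ([4096usize, 4096], 2);

        // Same arithmetic as TensorInfo::calc_size above.
        let mut size = type_size;
        for &dim in &dims[0..n_dims] {
            size *= dim;
        }
        let n_bytes = size / blck_size;

        assert_eq!(n_bytes, 10_485_760); // exactly 10 MiB for this weight matrix
    }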

/// Info in hyperparameter used for later loading tasks. Used in callback.
/// see [`LoadHandler::load_hyper_parameters`]
@@ -78,10 +86,7 @@ pub struct PartialHyperparameters {

pub enum TensorDataTreatment<'a> {
CopyInto(&'a mut [u8]),
SeekPast {
/// should be `tensor.nbytes`
n_bytes: usize,
},
Skip,
}

#[allow(unused_variables)]
@@ -173,7 +178,9 @@ pub fn load_weights<T, R: BufRead + Seek>(
// load tensor header
let n_dims: usize = read_i32(reader)?.try_into()?;
let name_len = read_i32(reader)?;
let ftype = decode_element_type_res(read_i32(reader)?)?;
let ftype = read_i32(reader)?;
let ftype =
ggml::Type::try_from(ftype).map_err(|_| LoadError::UnsupportedElementType(ftype))?;

let mut n_elements: usize = 1;
let mut dims = [1usize, 1];
@@ -214,9 +221,10 @@ pub fn load_weights<T, R: BufRead + Seek>(
dims,
n_dims,
n_elements,
ftype,
element_type: ftype,
start_offset: offset_aligned,
};
let n_bytes = tensor_info.calc_size();

match controlflow_to_result(handler.tensor_buffer(tensor_info))? {
TensorDataTreatment::CopyInto(buf) => {
@@ -225,7 +233,7 @@
}
reader.read_exact(buf)?;
}
TensorDataTreatment::SeekPast { n_bytes } => {
TensorDataTreatment::Skip => {
// skip if no buffer is given
reader.seek(SeekFrom::Start(offset_aligned + n_bytes as u64))?;
}
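
Note: the new Skip variant removes a footgun: handlers previously had to report n_bytes themselves, with only a doc comment ("should be tensor.nbytes") to keep them honest. Below is a sketch of a handler using it; the LoadHandler trait itself is outside this diff, and the ControlFlow return is inferred from the controlflow_to_result call above.

    use std::ops::ControlFlow;

    struct MyError; // placeholder error type for this sketch

    // Hypothetical handler method that skips every tensor. The loader now
    // derives the seek distance from TensorInfo::calc_size, so the handler
    // states only its intent.
    fn tensor_buffer(info: TensorInfo) -> ControlFlow<MyError, TensorDataTreatment<'static>> {
        let _ = info;
        ControlFlow::Continue(TensorDataTreatment::Skip)
    }
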
29 changes: 1 addition & 28 deletions ggml-loader/src/util.rs
@@ -1,7 +1,7 @@
pub use std::io::{BufRead, Seek, SeekFrom};
use std::ops::ControlFlow;

use crate::{ElementType, LoadError};
use crate::LoadError;

pub fn read_bytes<const N: usize>(reader: &mut impl BufRead) -> Result<[u8; N], std::io::Error> {
let mut bytes = [0u8; N];
@@ -35,33 +35,6 @@ pub fn has_data_left(reader: &mut impl BufRead) -> Result<bool, std::io::Error>
reader.fill_buf().map(|b| !b.is_empty())
}

pub fn decode_element_type(ftype: i32) -> Option<ElementType> {
match ftype {
0 => Some(ggml::Type::F32),
1 => Some(ggml::Type::F16),
2 => Some(ggml::Type::Q4_0),
3 => Some(ggml::Type::Q4_1),
_ => None,
}
}

pub fn encode_element_type(element_type: ElementType) -> Option<i32> {
match element_type {
ggml::Type::F32 => Some(0),
ggml::Type::F16 => Some(1),
ggml::Type::Q4_0 => Some(2),
ggml::Type::Q4_1 => Some(3),
_ => None,
}
}

pub fn decode_element_type_res<T>(ftype: i32) -> Result<ElementType, LoadError<T>> {
match decode_element_type(ftype) {
Some(x) => Ok(x),
None => Err(LoadError::UnsupportedElementType(ftype)),
}
}

pub fn controlflow_to_result<A, B>(x: ControlFlow<A, B>) -> Result<B, LoadError<A>> {
match x {
ControlFlow::Continue(x) => Ok(x),
23 changes: 23 additions & 0 deletions ggml/src/lib.rs
@@ -24,6 +24,9 @@ pub const FILE_MAGIC_UNVERSIONED: u32 = 0x67676d6c;
/// The currently-supported format version for `ggml` files.
pub const FORMAT_VERSION: u32 = 1;

/// The size of a `ggml` object.
pub const OBJECT_SIZE: usize = ggml_sys::GGML_OBJECT_SIZE;

#[derive(Debug, Copy, Clone, PartialEq, Eq, Default)]
/// The type of a value in `ggml`.
pub enum Type {
@@ -32,6 +35,12 @@ pub enum Type {
Q4_0,
/// Quantized 4-bit (type 1); used by GPTQ.
Q4_1,
/// Quantized 4-bit (type 2).
Q4_2,
/// Quantized 4-bit (type 3).
Q4_3,
/// Quantized 8-bit (type 0).
Q8_0,
/// Integer 32-bit.
I32,
/// Float 16-bit.
@@ -44,6 +53,9 @@ impl From<Type> for ggml_sys::ggml_type {
match t {
Type::Q4_0 => ggml_sys::ggml_type_GGML_TYPE_Q4_0,
Type::Q4_1 => ggml_sys::ggml_type_GGML_TYPE_Q4_1,
Type::Q4_2 => ggml_sys::ggml_type_GGML_TYPE_Q4_2,
Type::Q4_3 => ggml_sys::ggml_type_GGML_TYPE_Q4_3,
Type::Q8_0 => ggml_sys::ggml_type_GGML_TYPE_Q8_0,
Type::I32 => ggml_sys::ggml_type_GGML_TYPE_I32,
Type::F16 => ggml_sys::ggml_type_GGML_TYPE_F16,
Type::F32 => ggml_sys::ggml_type_GGML_TYPE_F32,
@@ -56,6 +68,9 @@ impl TryFrom<ggml_sys::ggml_type> for Type {
match t {
ggml_sys::ggml_type_GGML_TYPE_Q4_0 => Ok(Type::Q4_0),
ggml_sys::ggml_type_GGML_TYPE_Q4_1 => Ok(Type::Q4_1),
ggml_sys::ggml_type_GGML_TYPE_Q4_2 => Ok(Type::Q4_2),
ggml_sys::ggml_type_GGML_TYPE_Q4_3 => Ok(Type::Q4_3),
ggml_sys::ggml_type_GGML_TYPE_Q8_0 => Ok(Type::Q8_0),
ggml_sys::ggml_type_GGML_TYPE_I32 => Ok(Type::I32),
ggml_sys::ggml_type_GGML_TYPE_F16 => Ok(Type::F16),
ggml_sys::ggml_type_GGML_TYPE_F32 => Ok(Type::F32),
@@ -68,6 +83,9 @@ impl std::fmt::Display for Type {
match self {
Type::Q4_0 => write!(f, "q4_0"),
Type::Q4_1 => write!(f, "q4_1"),
Type::Q4_2 => write!(f, "q4_2"),
Type::Q4_3 => write!(f, "q4_3"),
Type::Q8_0 => write!(f, "q8_0"),
Type::I32 => write!(f, "i32"),
Type::F16 => write!(f, "f16"),
Type::F32 => write!(f, "f32"),
@@ -510,6 +528,11 @@ pub struct Tensor {
}

impl Tensor {
/// Size of the `ggml_tensor` struct in bytes.
///
/// Exposed for purposes of determining context size.
pub const C_TYPE_SIZE: usize = std::mem::size_of::<ggml_sys::ggml_tensor>();

/// Creates a shared copy of this tensor pointer.
pub fn share(&self) -> Self {
Tensor {
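
Note: a quick round-trip sketch of the widened conversions (not part of the diff, and it assumes the raw ggml_sys type is reachable from the calling crate):

    fn roundtrip() {
        let t = ggml::Type::Q8_0;
        let raw: ggml_sys::ggml_type = t.into();
        let back = ggml::Type::try_from(raw).expect("known type round-trips");
        assert_eq!(back, t);
        assert_eq!(back.to_string(), "q8_0");
    }
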
18 changes: 9 additions & 9 deletions llama-cli/src/cli_args.rs
@@ -373,12 +373,12 @@ pub struct Convert {
pub directory: PathBuf,

/// File type to convert to
#[arg(long, short = 't', value_enum, default_value_t = ElementType::Q4_0)]
pub element_type: ElementType,
#[arg(long, short = 't', value_enum, default_value_t = FileType::Q4_0)]
pub file_type: FileType,
}

#[derive(Parser, Debug, ValueEnum, Clone, Copy)]
pub enum ElementType {
pub enum FileType {
/// Quantized 4-bit (type 0).
Q4_0,
/// Quantized 4-bit (type 1); used by GPTQ.
@@ -388,13 +388,13 @@ pub enum ElementType {
/// Float 32-bit.
F32,
}
impl From<ElementType> for llama_rs::ElementType {
fn from(t: ElementType) -> Self {
impl From<FileType> for llama_rs::FileType {
fn from(t: FileType) -> Self {
match t {
ElementType::Q4_0 => llama_rs::ElementType::Q4_0,
ElementType::Q4_1 => llama_rs::ElementType::Q4_1,
ElementType::F16 => llama_rs::ElementType::F16,
ElementType::F32 => llama_rs::ElementType::F32,
FileType::Q4_0 => llama_rs::FileType::MostlyQ4_0,
FileType::Q4_1 => llama_rs::FileType::MostlyQ4_1,
FileType::F16 => llama_rs::FileType::MostlyF16,
FileType::F32 => llama_rs::FileType::F32,
}
}
}
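
Note: a usage sketch of the boundary conversion (hypothetical helper; the --file-type flag name is assumed from the renamed struct field):

    // clap parses -t / --file-type into the CLI-local FileType; it is
    // converted at the llama_rs boundary, as main.rs does below.
    fn convert_arg(cli_choice: FileType) -> llama_rs::FileType {
        cli_choice.into()
    }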
2 changes: 1 addition & 1 deletion llama-cli/src/main.rs
@@ -22,7 +22,7 @@ fn main() -> Result<()> {
Args::DumpTokens(args) => dump_tokens(&args)?,
Args::Repl(args) => interactive(&args, false)?,
Args::ChatExperimental(args) => interactive(&args, true)?,
Args::Convert(args) => convert_pth_to_ggml(&args.directory, args.element_type.into()),
Args::Convert(args) => convert_pth_to_ggml(&args.directory, args.file_type.into()),
}

Ok(())
19 changes: 7 additions & 12 deletions llama-rs/src/convert.rs
@@ -16,20 +16,19 @@ use std::{
vec,
};

use crate::{util, Hyperparameters, Vocabulary};
use ggml_loader::util::encode_element_type;
use crate::{loader_common::FileType, util, Hyperparameters, Vocabulary};

/// Converts a `pth` file to a `ggml` file.
pub fn convert_pth_to_ggml(model_directory: &Path, element_type: ggml::Type) {
pub fn convert_pth_to_ggml(model_directory: &Path, file_type: FileType) {
let tokenizer_path = model_directory.parent().unwrap().join("tokenizer.model");
let vocab = load_vocabulary(tokenizer_path.as_path());

let hparams = load_hyperparameters(model_directory, element_type, &vocab);
let hparams = load_hyperparameters(model_directory, file_type, &vocab);

let model_files = util::find_all_model_files(model_directory).unwrap();

for (i, _file) in model_files.iter().enumerate() {
let fname_out = model_directory.join(format!("rust-model-{element_type}.bin"));
let fname_out = model_directory.join(format!("rust-model-{file_type}.bin"));
let mut file = File::create(fname_out).expect("Unable to create file");
write_header(file.borrow_mut(), &hparams).unwrap();
write_tokens(file.borrow_mut(), &vocab).unwrap();
@@ -66,11 +65,7 @@ fn load_vocabulary(path: &Path) -> Vocabulary {
}
}

fn load_hyperparameters(
path: &Path,
element_type: ggml::Type,
vocab: &Vocabulary,
) -> Hyperparameters {
fn load_hyperparameters(path: &Path, file_type: FileType, vocab: &Vocabulary) -> Hyperparameters {
#[derive(Deserialize)]
struct HyperParametersJson {
dim: usize,
@@ -83,7 +78,7 @@
let json = read_to_string(path.join("params.json")).expect("Unable to read file");
let json: HyperParametersJson = serde_json::from_str(&json).expect("Unable to parse json");
Hyperparameters {
element_type,
file_type,
n_ctx: 0,
n_embd: json.dim,
n_head: json.n_heads,
Expand All @@ -107,7 +102,7 @@ fn write_header(fout: &mut File, hparams: &Hyperparameters) -> Result<(), String
i32::try_from(hparams.n_head).unwrap(),
i32::try_from(hparams.n_layer).unwrap(),
i32::try_from(hparams.n_embd / hparams.n_head).unwrap(),
encode_element_type(hparams.element_type).unwrap(),
hparams.file_type.into(),
];
let mut packed_values: Vec<u8> = vec![];

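
Note: hparams.file_type.into() implies an impl From<FileType> for i32 in loader_common.rs, which this diff does not show. A minimal sketch of what it presumably encodes, assuming the llama.cpp ftype numbering:

    // Stand-in enum: the real FileType lives in loader_common.rs and has
    // more variants than shown here.
    enum FileType { F32, MostlyF16, MostlyQ4_0, MostlyQ4_1 }

    impl From<FileType> for i32 {
        fn from(t: FileType) -> Self {
            match t {
                FileType::F32 => 0,       // assumed: all tensors f32
                FileType::MostlyF16 => 1, // assumed: f16 weights
                FileType::MostlyQ4_0 => 2,
                FileType::MostlyQ4_1 => 3,
            }
        }
    }
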
4 changes: 2 additions & 2 deletions llama-rs/src/inference_session.rs
@@ -68,7 +68,7 @@ impl InferenceSession {
.map(|(_, tok)| *tok)
.collect();

if self.n_past + prompt_tokens.len() >= model.hparams.n_ctx {
if self.n_past + prompt_tokens.len() >= model.n_ctx() {
return Err(InferenceError::ContextFull);
}

@@ -96,7 +96,7 @@ impl InferenceSession {
params: &InferenceParameters,
rng: &mut impl rand::Rng,
) -> Result<&'v [u8], InferenceError> {
if self.n_past + 1 >= model.hparams.n_ctx {
if self.n_past + 1 >= model.n_ctx() {
return Err(InferenceError::ContextFull);
}

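
Note: sessions now ask the model for its context length instead of reaching into hparams directly. The accessor itself is outside this hunk; one plausible shape, assuming n_ctx is still stored on the hyperparameters:

    impl Model {
        /// Hypothetical accessor; the backing field is not visible in this diff.
        pub fn n_ctx(&self) -> usize {
            self.hparams.n_ctx
        }
    }
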
2 changes: 1 addition & 1 deletion llama-rs/src/lib.rs
@@ -19,7 +19,7 @@ pub use inference_session::{
InferenceSession, InferenceSessionParameters, InferenceSnapshot, ModelKVMemoryType,
SnapshotError,
};
pub use loader_common::{LoadError, LoadProgress};
pub use loader_common::{FileType, LoadError, LoadProgress};
pub use model::{Hyperparameters, Model};
pub use util::TokenUtf8Buffer;
pub use vocabulary::{TokenBias, TokenId, Vocabulary};
17 changes: 12 additions & 5 deletions llama-rs/src/loader.rs
@@ -7,6 +7,7 @@ use std::{
};

use crate::{
loader_common::FileType,
util::{self, mulf},
LoadError, LoadProgress, Model, TokenId, Vocabulary,
};
@@ -69,9 +70,9 @@ pub(crate) fn load(
n_head: read_i32(&mut reader)?.try_into()?,
n_layer: read_i32(&mut reader)?.try_into()?,
n_rot: read_i32(&mut reader)?.try_into()?,
element_type: {
file_type: {
let ftype = read_i32(&mut reader)?;
decode_element_type(ftype).ok_or_else(|| LoadError::UnsupportedElementType(ftype))
FileType::try_from(ftype).map_err(|_| LoadError::UnsupportedFileType(ftype))
}?,
};

@@ -108,7 +109,13 @@ pub(crate) fn load(
// for the big tensors, we have the option to store the data in 16-bit
// floats or quantized in order to save memory and also to speed up the
// computation
let wtype = hparams.element_type;
let wtype = match hparams.file_type {
FileType::F32 => ggml::Type::F32,
FileType::MostlyF16 => ggml::Type::F16,
FileType::MostlyQ4_0 => ggml::Type::Q4_0,
FileType::MostlyQ4_1 => ggml::Type::Q4_1,
_ => unimplemented!(),
};

let n_embd = hparams.n_embd;
let n_layer = hparams.n_layer;
@@ -159,7 +166,7 @@ pub(crate) fn load(
(None, None)
};

let mut model = Model::new(context, hparams, vocabulary, n_ff, wtype, model_type, mmap);
let mut model = Model::new_loader1(context, hparams, vocabulary, n_ff, wtype, mmap);
match model_type {
ContainerType::GGMF | ContainerType::GGML => {
let file_offset = reader.stream_position()?;
@@ -421,7 +428,7 @@ fn load_tensor_header_ggmf<'a>(
}

fn tensor_type_size(ftype: i32, ne: [i64; 2]) -> Option<usize> {
let ftype = decode_element_type(ftype)?;
let ftype = ggml::Type::try_from(ftype).ok()?;
match ftype {
ElementType::Q4_0 | ElementType::Q4_1 => {
assert_eq!(ne[0] % 64, 0);
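
Note: the truncated tensor_type_size arm presumably finishes with the same block arithmetic calc_size uses. A sketch of that computation, assumed from the ggml helpers the new calc_size relies on:

    fn quantized_nbytes(ne: [i64; 2], ty: ggml::Type) -> usize {
        let n_elements = (ne[0] * ne[1]) as usize;
        // bytes = elements * bytes-per-block / elements-per-block
        n_elements * ggml::type_size(ty) / ggml::blck_size(ty)
    }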