From 4208bc710c02a60ca488e728cb943887a8bae8be Mon Sep 17 00:00:00 2001 From: KerfuffleV2 Date: Wed, 12 Apr 2023 19:03:36 -0600 Subject: [PATCH] Support GGML q4_0 and q4_1 quantization. --- README.md | 8 ++ smolrwkv-cli/src/args.rs | 14 ++- smolrwkv-cli/src/main.rs | 18 ++-- smolrwkv-cli/src/util.rs | 1 + smolrwkv/Cargo.toml | 3 + smolrwkv/src/ggml/loader.rs | 206 +++++++++++++++++++++++++++--------- 6 files changed, 193 insertions(+), 57 deletions(-) diff --git a/README.md b/README.md index 40089f8..218feb8 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,13 @@ # Smol Rust RWKV +## Dev Branch Note + +This branch currently needs a patched version of `ggml` and `ggml-sys` from the `llama-rs` project. + +You can find it here: https://github.com/KerfuffleV2/llama-rs/tree/experiment-ggml-map-ops + +`smolrwkv/Cargo.toml` is set up to look for it checked out in `../llama-rs` relative to this repo's directory. + ## What is it? A simple example of the RWKV approach to language models written in Rust by someone that diff --git a/smolrwkv-cli/src/args.rs b/smolrwkv-cli/src/args.rs index 5ccf4aa..713a620 100644 --- a/smolrwkv-cli/src/args.rs +++ b/smolrwkv-cli/src/args.rs @@ -10,11 +10,23 @@ const DEFAULT_MODEL: &str = "./RWKV-4-Pile-430M-20220808-8066.safetensors"; const DEFAULT_TOKENIZER: &str = "./20B_tokenizer.json"; #[derive(Debug, Clone, Copy, ValueEnum)] -#[value(rename_all = "lower")] pub enum EvalType { + #[value(name = "ndf32")] + /// ndarray-backed 32 bit floats. Uses a lot of memory. NDf32, + #[value(name = "ndq8")] + /// ndarray-backed 8 bit quantized. Better memory usage but quite slow. NDu8, + #[value(name = "ggmlf32")] + /// GGML-backed 32 bit. As above, uses a lot of memory. GGMLf32, + #[value(name = "ggmlq4_0")] + /// GGML-backed 4 bit quantized, method 1. Poor quality. + GGMLQ4_0, + #[value(name = "ggmlq4_1")] + /// GGML-backed 4 bit quantized, method 2. Decenent quality, + /// but slower (to load?) + GGMLQ4_1, } #[derive(Clone, Debug, Parser)] diff --git a/smolrwkv-cli/src/main.rs b/smolrwkv-cli/src/main.rs index 4116ded..e77f2ec 100644 --- a/smolrwkv-cli/src/main.rs +++ b/smolrwkv-cli/src/main.rs @@ -59,7 +59,7 @@ fn go() -> Result<()> { args::EvalType::NDf32 => { Ctx::NdFloat32(run_threadlimited(args.max_load_threads, move || { anyhow::Ok({ - info!("Model type: non-quantized (full 32bit)."); + info!("Backend type: NDArray non-quantized (full 32bit)."); S::context::RWKVContext::>::new( tdm.try_into()?, tokenizer, @@ -71,16 +71,22 @@ fn go() -> Result<()> { args::EvalType::NDu8 => { Ctx::NdQuant8(run_threadlimited(args.max_load_threads, move || { anyhow::Ok({ - info!("Model type: 8 bit-quantized weights."); + info!("Backend type: NDArray 8 bit-quantized weights."); S::context::RWKVContext::::new(tdm.try_into()?, tokenizer) }) })?) } - args::EvalType::GGMLf32 => { - use smolrwkv::ggml::context::RWKVContext; - info!("Model type: GGML 32 bit."); + args::EvalType::GGMLf32 | args::EvalType::GGMLQ4_0 | args::EvalType::GGMLQ4_1 => { + use smolrwkv::ggml::{context::RWKVContext, loader::RwkvGgmlType}; + let wtype = match args.eval_mode { + args::EvalType::GGMLf32 => RwkvGgmlType::Float32, + args::EvalType::GGMLQ4_0 => RwkvGgmlType::Q4_0, + args::EvalType::GGMLQ4_1 => RwkvGgmlType::Q4_1, + _ => panic!("Impossible: Bad eval mode!"), + }; + info!("Backend type: GGML {wtype:?}"); Ctx::GgmlFloat32(RWKVContext::new( - tdm.try_into()?, + (wtype, tdm).try_into()?, tokenizer, args.max_eval_threads, )) diff --git a/smolrwkv-cli/src/util.rs b/smolrwkv-cli/src/util.rs index 0d59732..fd07a5e 100644 --- a/smolrwkv-cli/src/util.rs +++ b/smolrwkv-cli/src/util.rs @@ -1,3 +1,4 @@ +#![allow(dead_code)] use std::io::Write; use anyhow::Result; diff --git a/smolrwkv/Cargo.toml b/smolrwkv/Cargo.toml index aa33ee4..0cb78f0 100644 --- a/smolrwkv/Cargo.toml +++ b/smolrwkv/Cargo.toml @@ -31,3 +31,6 @@ optional = true [dependencies.ggml] path = "../../llama-rs/ggml" + +[dependencies.ggml-sys] +path = "../../llama-rs/ggml-sys" diff --git a/smolrwkv/src/ggml/loader.rs b/smolrwkv/src/ggml/loader.rs index 347ae28..0b4b0f6 100644 --- a/smolrwkv/src/ggml/loader.rs +++ b/smolrwkv/src/ggml/loader.rs @@ -7,7 +7,7 @@ use anyhow::{anyhow, Error, Ok as AOk, Result}; use ndarray::{Array1, Array2}; use tracing::{info, instrument}; -use ggml::{Context, Tensor}; +use ggml::{Context, Tensor, Type as GT}; use super::model::*; use crate::{ @@ -16,10 +16,34 @@ use crate::{ }; type Ty = f32; -const GT32: ggml::Type = ggml::Type::F32; +const GT32: ggml::Type = GT::F32; -/// LayerMap helper type to avoid repetition. -type BuildCtx<'a, 'b> = (&'a Context, &'a HashMap>); +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +pub enum RwkvGgmlType { + Float32, + Q4_0, + Q4_1, +} + +#[allow(clippy::from_over_into)] +// Note: Only Into here because can't handle all GGML types. +impl Into for RwkvGgmlType { + fn into(self) -> ggml::Type { + match self { + RwkvGgmlType::Float32 => GT::F32, + RwkvGgmlType::Q4_0 => GT::Q4_0, + RwkvGgmlType::Q4_1 => GT::Q4_1, + } + } +} + +struct BuildCtx<'a, 'b> { + n_layers: usize, + lnum: usize, + ctx: &'a Context, + lm: &'a HashMap>, + wtype: RwkvGgmlType, +} #[repr(transparent)] struct Tents(Tensor); @@ -30,6 +54,45 @@ impl From for Tents { } } +fn quantize(bctx: &BuildCtx<'_, '_>, arr: Array2) -> Tensor { + let wtype = bctx.wtype; + + let nels = arr.len(); + let shp = arr.shape(); + let in_size = nels * 4; + // FIXME: Verify this is safe, but 32bit -> 4bit shouldn't take more than 8bits per element + // plus maybe an extra block. Riiight? + let mut output: Vec = Vec::with_capacity((in_size / 4) + ggml::blck_size(wtype.into())); + let mut hist = [0i64; 16]; + output.fill(0); + let out_size = unsafe { + match wtype { + RwkvGgmlType::Q4_0 => ggml_sys::ggml_quantize_q4_0( + arr.as_ptr(), + output.as_mut_ptr() as *mut std::ffi::c_void, + nels as i32, + shp[1] as i32, + hist.as_mut_ptr(), + ), + RwkvGgmlType::Q4_1 => ggml_sys::ggml_quantize_q4_1( + arr.as_ptr(), + output.as_mut_ptr() as *mut std::ffi::c_void, + nels as i32, + shp[1] as i32, + hist.as_mut_ptr(), + ), + _ => panic!("Bad weight type!"), + } + }; + info!( + "--> QUANT: len {in_size} -> {out_size} ({})", + in_size - out_size + ); + let t = bctx.ctx.new_tensor_2d(wtype.into(), shp[1], shp[0]); + unsafe { (t.data() as *mut u8).copy_from_nonoverlapping(output.as_ptr(), out_size) } + t +} + /// Helper function for extracting a 1d tensor from the HashMap by string key. /// Takes a closure to convert from the TensorData struct to a usable format. fn gk1(ctx: &Context, lm: &HashMap>, key: &str) -> Result { @@ -56,6 +119,24 @@ fn gk2(ctx: &Context, lm: &HashMap>, key: &str) -> Result Ok(t) } +fn qgk2(bctx: &BuildCtx<'_, '_>, key: &str) -> Result { + if bctx.wtype == RwkvGgmlType::Float32 { + return gk2(bctx.ctx, bctx.lm, key); + } + + let arr = >>::convert_tensor( + bctx.lm.get(key).ok_or_else(|| anyhow!("Bad format"))?, + )?; + info!( + "[{}/{}]: Quantizing {key}({:?})", + bctx.n_layers, + bctx.lnum + 1, + arr.shape() + ); + let t = quantize(bctx, arr); + Ok(t) +} + impl From<(&Context, Array1)> for Tents { fn from((ctx, arr): (&Context, Array1)) -> Self { let shp = arr.shape(); @@ -75,23 +156,24 @@ impl From<(&Context, Array2)> for Tents { } } -impl TryFrom<(usize, BuildCtx<'_, '_>)> for LayerNorm { +impl TryFrom<(usize, &BuildCtx<'_, '_>)> for LayerNorm { type Error = Error; #[instrument(skip_all, name = "convert_layer_norm", level = "DEBUG")] - fn try_from((idx, (ctx, lm)): (usize, BuildCtx<'_, '_>)) -> Result { + fn try_from((idx, bctx): (usize, &BuildCtx<'_, '_>)) -> Result { Ok(Self { - weight: gk1(ctx, lm, &format!("ln{idx}.weight"))?, - bias: gk1(ctx, lm, &format!("ln{idx}.bias"))?, + weight: gk1(bctx.ctx, bctx.lm, &format!("ln{idx}.weight"))?, + bias: gk1(bctx.ctx, bctx.lm, &format!("ln{idx}.bias"))?, }) } } -impl TryFrom> for AttTime { +impl TryFrom<&BuildCtx<'_, '_>> for AttTime { type Error = Error; #[instrument(skip_all, err, name = "convert_attn_time_mix", level = "DEBUG")] - fn try_from((ctx, lm): BuildCtx<'_, '_>) -> Result { + fn try_from(bctx: &BuildCtx<'_, '_>) -> Result { + let (ctx, lm) = (bctx.ctx, bctx.lm); let mut decay = >>::convert_tensor( lm.get("att.time_decay") .ok_or_else(|| anyhow!("Bad format"))?, @@ -108,26 +190,27 @@ impl TryFrom> for AttTime { } } -impl TryFrom> for Attention { +impl TryFrom<&BuildCtx<'_, '_>> for Attention { type Error = Error; #[instrument(skip_all, name = "convert_att", level = "DEBUG")] - fn try_from(bctx @ (ctx, lm): BuildCtx<'_, '_>) -> Result { + fn try_from(bctx: &BuildCtx<'_, '_>) -> Result { Ok(Self { - key_weight: gk2(ctx, lm, "att.key.weight")?, - value_weight: gk2(ctx, lm, "att.value.weight")?, - output_weight: gk2(ctx, lm, "att.output.weight")?, - receptance_weight: gk2(ctx, lm, "att.receptance.weight")?, + key_weight: qgk2(bctx, "att.key.weight")?, + value_weight: qgk2(bctx, "att.value.weight")?, + output_weight: qgk2(bctx, "att.output.weight")?, + receptance_weight: qgk2(bctx, "att.receptance.weight")?, time: AttTime::try_from(bctx)?, }) } } -impl TryFrom> for FFNTime { +impl TryFrom<&BuildCtx<'_, '_>> for FFNTime { type Error = Error; #[instrument(skip_all, name = "convert_ffn_time_mix", level = "DEBUG")] - fn try_from((ctx, lm): BuildCtx<'_, '_>) -> Result { + fn try_from(bctx: &BuildCtx<'_, '_>) -> Result { + let (ctx, lm) = (bctx.ctx, bctx.lm); Ok(Self { mix_k: Mix(gk1(ctx, lm, "ffn.time_mix_k")?), mix_r: Mix(gk1(ctx, lm, "ffn.time_mix_r")?), @@ -135,25 +218,25 @@ impl TryFrom> for FFNTime { } } -impl TryFrom> for FeedForwardNetwork { +impl TryFrom<&BuildCtx<'_, '_>> for FeedForwardNetwork { type Error = Error; #[instrument(skip_all, name = "convert_ffn", level = "DEBUG")] - fn try_from(bctx @ (ctx, lm): BuildCtx<'_, '_>) -> Result { + fn try_from(bctx: &BuildCtx<'_, '_>) -> Result { Ok(FeedForwardNetwork { - key_weight: gk2(ctx, lm, "ffn.key.weight")?, - value_weight: gk2(ctx, lm, "ffn.value.weight")?, - receptance_weight: gk2(ctx, lm, "ffn.receptance.weight")?, + key_weight: qgk2(bctx, "ffn.key.weight")?, + value_weight: qgk2(bctx, "ffn.value.weight")?, + receptance_weight: qgk2(bctx, "ffn.receptance.weight")?, time: FFNTime::try_from(bctx)?, }) } } -impl TryFrom> for RWKVLayer { +impl TryFrom<&BuildCtx<'_, '_>> for RWKVLayer { type Error = Error; #[instrument(skip_all, name = "convert_layer", level = "DEBUG")] - fn try_from(bctx: BuildCtx<'_, '_>) -> Result { + fn try_from(bctx: &BuildCtx<'_, '_>) -> Result { Ok(Self { ln_tm: LayerNorm::try_from((1, bctx))?, ln_cm: LayerNorm::try_from((2, bctx))?, @@ -163,15 +246,16 @@ impl TryFrom> for RWKVLayer { } } -impl TryFrom> for RWKV { +impl TryFrom<(RwkvGgmlType, TensorDataMap<'_>)> for RWKV { type Error = Error; #[instrument(skip_all, name = "load_model")] - fn try_from(tensors: TensorDataMap<'_>) -> Result { + fn try_from((wtype, tensors): (RwkvGgmlType, TensorDataMap<'_>)) -> Result { info!("Discovering model structure."); let mut layers = Vec::with_capacity(32); let mut nlm = HashMap::default(); - tensors.into_iter().try_for_each(|(mut name, tensor)| { + tensors.iter().try_for_each(|(name, tensor)| { + let mut name = name.to_owned(); if let Some(rest) = name.strip_prefix("blocks.") { let result = rest.split_once('.').ok_or_else(|| anyhow!("Bad format"))?; let lnum: usize = result.0.parse()?; @@ -180,10 +264,10 @@ impl TryFrom> for RWKV { } name = result.1.to_string(); - layers[lnum].insert(name, tensor); + layers[lnum].insert(name, tensor.clone()); AOk(()) } else { - nlm.insert(name, tensor); + nlm.insert(name.to_owned(), tensor.clone()); Ok(()) } })?; @@ -196,32 +280,30 @@ impl TryFrom> for RWKV { ); anyhow::ensure!(!nlm.is_empty(), "Missing non-layer tensors!"); - // FIXME; Real stuff here. - let ctx_size = 12 * 1024 * 1024 * 1024; + let (n_vocab, n_embed) = nlm + .get("emb.weight") + .ok_or_else(|| anyhow!("Bad format")) + .map(|x| { + let shp = &x.shape; + assert_eq!(shp.len(), 2, "Bad shape for emb.weight!"); + (shp[0], shp[1]) + // + })?; + // FIXME; Better stuff here. + let ctx_size = match wtype { + RwkvGgmlType::Float32 => (n_layers * n_embed * n_vocab) * 3, + RwkvGgmlType::Q4_0 | RwkvGgmlType::Q4_1 => (n_layers + 2) * n_embed * n_vocab, + }; let ctx = ggml::Context::init(ctx_size); - let ln0 = crate::simple::model::LayerNorm::::try_from((0, &layers[0]))?; - - info!("Loading {n_layers} layer(s):"); - let layers = layers - .into_iter() - .map(|lm| { - print!("."); - stdout().flush().ok(); - RWKVLayer::try_from((&ctx, &lm)) - }) - .collect::, _>>()?; - - println!(); - info!("Precomputing embedding..."); // It's possible to just precompute the embeddings in advance. - let (emb, n_embed, n_vocab) = { + let emb = { + let ln0 = crate::simple::model::LayerNorm::::try_from((0, &layers[0]))?; + info!("Precomputing embedding..."); let mut emba = >>::convert_tensor( nlm.get("emb.weight").ok_or_else(|| anyhow!("Bad format"))?, )?; - let embashp = emba.shape(); - let (n_vocab, n_embed) = (embashp[0], embashp[1]); (0..n_vocab).for_each(|idx| { use crate::model_traits::RunLayerNorm; @@ -232,10 +314,34 @@ impl TryFrom> for RWKV { idxemb.copy_from_slice(&ln0.norm(&idxemb).into_raw_vec()); }); drop(ln0); + let Tents(emb) = (&ctx, emba).into(); - (emb, n_embed, n_vocab) + emb }; + info!("Loading {n_layers} layer(s):"); + + let layers = layers + .into_iter() + .enumerate() + .map(|(lnum, lm)| { + let bctx = BuildCtx { + n_layers, + lnum, + ctx: &ctx, + lm: &lm, + wtype, + }; + if wtype == RwkvGgmlType::Float32 { + print!("."); + stdout().flush().ok(); + } + RWKVLayer::try_from(&bctx) + }) + .collect::, _>>()?; + + println!(); + info!("Loading non-layer tensors."); Ok(RWKV {