Skip to content

Commit

Permalink
Support GGML q4_0 and q4_1 quantization.
Browse files Browse the repository at this point in the history
  • Loading branch information
KerfuffleV2 committed Apr 13, 2023
1 parent 7cf5945 commit 4208bc7
Show file tree
Hide file tree
Showing 6 changed files with 193 additions and 57 deletions.
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
# Smol Rust RWKV

## Dev Branch Note

This branch currently needs a patched version of `ggml` and `ggml-sys` from the `llama-rs` project.

You can find it here: https://github.com/KerfuffleV2/llama-rs/tree/experiment-ggml-map-ops

`smolrwkv/Cargo.toml` is set up to look for it checked out in `../llama-rs` relative to this repo's directory.

## What is it?

A simple example of the RWKV approach to language models written in Rust by someone that
Expand Down
14 changes: 13 additions & 1 deletion smolrwkv-cli/src/args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,23 @@ const DEFAULT_MODEL: &str = "./RWKV-4-Pile-430M-20220808-8066.safetensors";
const DEFAULT_TOKENIZER: &str = "./20B_tokenizer.json";

/// Which evaluation backend / weight format to use.
///
/// Exposed on the CLI via clap's `ValueEnum`; the `#[value(name = "...")]`
/// attributes pin the exact strings users pass (e.g. `--eval-mode ggmlq4_0`),
/// so they must stay stable even if variant identifiers change.
#[derive(Debug, Clone, Copy, ValueEnum)]
#[value(rename_all = "lower")]
pub enum EvalType {
    #[value(name = "ndf32")]
    /// ndarray-backed 32 bit floats. Uses a lot of memory.
    NDf32,

    #[value(name = "ndq8")]
    /// ndarray-backed 8 bit quantized. Better memory usage but quite slow.
    NDu8,

    #[value(name = "ggmlf32")]
    /// GGML-backed 32 bit. As above, uses a lot of memory.
    GGMLf32,

    #[value(name = "ggmlq4_0")]
    /// GGML-backed 4 bit quantized, method 1. Poor quality.
    GGMLQ4_0,

    #[value(name = "ggmlq4_1")]
    /// GGML-backed 4 bit quantized, method 2. Decent quality,
    /// but slower (to load?)
    GGMLQ4_1,
}

#[derive(Clone, Debug, Parser)]
Expand Down
18 changes: 12 additions & 6 deletions smolrwkv-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ fn go() -> Result<()> {
args::EvalType::NDf32 => {
Ctx::NdFloat32(run_threadlimited(args.max_load_threads, move || {
anyhow::Ok({
info!("Model type: non-quantized (full 32bit).");
info!("Backend type: NDArray non-quantized (full 32bit).");
S::context::RWKVContext::<FloatType, Array2<FloatType>>::new(
tdm.try_into()?,
tokenizer,
Expand All @@ -71,16 +71,22 @@ fn go() -> Result<()> {
args::EvalType::NDu8 => {
Ctx::NdQuant8(run_threadlimited(args.max_load_threads, move || {
anyhow::Ok({
info!("Model type: 8 bit-quantized weights.");
info!("Backend type: NDArray 8 bit-quantized weights.");
S::context::RWKVContext::<FloatType, TensorQ2>::new(tdm.try_into()?, tokenizer)
})
})?)
}
args::EvalType::GGMLf32 => {
use smolrwkv::ggml::context::RWKVContext;
info!("Model type: GGML 32 bit.");
args::EvalType::GGMLf32 | args::EvalType::GGMLQ4_0 | args::EvalType::GGMLQ4_1 => {
use smolrwkv::ggml::{context::RWKVContext, loader::RwkvGgmlType};
let wtype = match args.eval_mode {
args::EvalType::GGMLf32 => RwkvGgmlType::Float32,
args::EvalType::GGMLQ4_0 => RwkvGgmlType::Q4_0,
args::EvalType::GGMLQ4_1 => RwkvGgmlType::Q4_1,
_ => panic!("Impossible: Bad eval mode!"),
};
info!("Backend type: GGML {wtype:?}");
Ctx::GgmlFloat32(RWKVContext::new(
tdm.try_into()?,
(wtype, tdm).try_into()?,
tokenizer,
args.max_eval_threads,
))
Expand Down
1 change: 1 addition & 0 deletions smolrwkv-cli/src/util.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#![allow(dead_code)]
use std::io::Write;

use anyhow::Result;
Expand Down
3 changes: 3 additions & 0 deletions smolrwkv/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,6 @@ optional = true

[dependencies.ggml]
path = "../../llama-rs/ggml"

[dependencies.ggml-sys]
path = "../../llama-rs/ggml-sys"
Loading

0 comments on commit 4208bc7

Please sign in to comment.