Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Q4 quantization support #197

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion bindings/python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,10 @@ fn prepare(tensor_dict: HashMap<String, &PyDict>) -> PyResult<HashMap<String, Te
dtype = match value {
"bool" => Some(Dtype::BOOL),
"int8" => Some(Dtype::I8),
"uint8" => Some(Dtype::U8),
"q4_0" => Some(Dtype::Q4_0),
"q4_1" => Some(Dtype::Q4_1),
"int16" => Some(Dtype::I16),
"uint8" => Some(Dtype::U8),
"uint16" => Some(Dtype::U16),
"int32" => Some(Dtype::I32),
"uint32" => Some(Dtype::U32),
Expand Down Expand Up @@ -874,6 +876,11 @@ fn get_pydtype(module: &PyModule, dtype: Dtype) -> PyResult<PyObject> {
Dtype::U8 => module.getattr(intern!(py, "uint8"))?.into(),
Dtype::I8 => module.getattr(intern!(py, "int8"))?.into(),
Dtype::BOOL => module.getattr(intern!(py, "bool"))?.into(),
Dtype::Q4_1 | Dtype::Q4_0 => {
return Err(SafetensorError::new_err(format!(
"Dtype not supported by framework: {dtype:?}"
)))
}
dtype => {
return Err(SafetensorError::new_err(format!(
"Dtype not understood: {dtype:?}"
Expand Down
45 changes: 41 additions & 4 deletions safetensors/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ pub enum SafeTensorError {
TensorNotFound(String),
/// Invalid information between shape, dtype and the proposed offsets in the file
TensorInvalidInfo,
/// Invalid information between shape, dtype and the proposed offsets in the file
/// The total number of bytes for the buffer is not an integer
QuantizationMisaligned,
/// The offsets declared for tensor with name `String` in the header are invalid
InvalidOffset(String),
/// IoError
Expand Down Expand Up @@ -461,7 +464,11 @@ impl Metadata {
}
start = e;
let nelements: usize = info.shape.iter().product();
let nbytes = nelements * info.dtype.size();
let nbits = nelements * info.dtype.nbits();
if !nbits % 8 == 0 {
return Err(SafeTensorError::QuantizationMisaligned);
}
let nbytes = nbits / 8;
if (e - s) != nbytes {
return Err(SafeTensorError::TensorInvalidInfo);
}
Expand Down Expand Up @@ -570,6 +577,12 @@ pub struct TensorInfo {
pub enum Dtype {
/// Boolan type
BOOL,
/// Quantized int4 format
/// Description <https://github.com/ggerganov/ggml/pull/27>
Q4_0,
/// Quantized int4 format (more precise, but more computation than Q4_0.
/// Description <https://github.com/ggerganov/ggml/pull/27>
Q4_1,
/// Unsigned byte
U8,
/// Signed byte
Expand Down Expand Up @@ -601,20 +614,44 @@ impl Dtype {
pub fn size(&self) -> usize {
match self {
Dtype::BOOL => 1,
Dtype::Q4_0 => 1,
Dtype::Q4_1 => 1,
Dtype::U8 => 1,
Dtype::I8 => 1,
Dtype::I16 => 2,
Dtype::U16 => 2,
Dtype::F16 => 2,
Dtype::BF16 => 2,
Dtype::I32 => 4,
Dtype::U32 => 4,
Dtype::F32 => 4,
Dtype::I64 => 8,
Dtype::U64 => 8,
Dtype::F16 => 2,
Dtype::BF16 => 2,
Dtype::F32 => 4,
Dtype::F64 => 8,
}
}

/// Gives out the size (in bits) of 1 element of this dtype.
/// This is important for sub-byte types like q4_0 and q4_1
pub fn nbits(&self) -> usize {
match self {
Dtype::Q4_0 => 4,
Dtype::Q4_1 => 4,
Dtype::BOOL => 8,
Dtype::U8 => 8,
Dtype::I8 => 8,
Dtype::I16 => 16,
Dtype::U16 => 16,
Dtype::F16 => 16,
Dtype::BF16 => 16,
Dtype::I32 => 32,
Dtype::U32 => 32,
Dtype::F32 => 32,
Dtype::I64 => 64,
Dtype::U64 => 64,
Dtype::F64 => 64,
}
}
}

#[cfg(test)]
Expand Down