Skip to content

Commit

Permalink
Merge pull request #248 from huggingface/feature/qindex-select
Browse files Browse the repository at this point in the history
Feature/qindex-select
  • Loading branch information
FL33TW00D authored Aug 27, 2024
2 parents ae754ca + 1881720 commit ae33cee
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 8 deletions.
2 changes: 1 addition & 1 deletion crates/ratchet-core/src/cpu/gemm.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::{
cpu_store_result, CPUOperation, DType, InvariantError, Matmul, MatmulSpec, OperationError,
Shape, Strides, Tensor, TensorDType,
Shape, Tensor, TensorDType,
};
use anyhow::{anyhow, Result};
use core::str::FromStr;
Expand Down
30 changes: 27 additions & 3 deletions crates/ratchet-core/src/cpu/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ pub mod gemm;

use crate::{
Binary, BinaryOp, CPUBuffer, CPUOperation, Cast, DType, IndexSelect, InvariantError, OpGuards,
Operation, OperationError, RVec, Storage, StorageView, Tensor, TensorDType, Unary, UnaryOp,
Operation, OperationError, Quantization, Quantizer, RVec, Segments, Storage, StorageView,
Tensor, TensorDType, Unary, UnaryOp,
};
use anyhow::anyhow;
use bytemuck::NoUninit;
Expand Down Expand Up @@ -218,12 +219,35 @@ fn index_select<T: TensorDType>(
Ok(dst)
}

/// Runs index-select on the CPU for a quantized source tensor.
///
/// The source is dequantized in full and the result is handed to the regular
/// `index_select::<f32>` kernel together with the original indices, dim and `dst`.
///
/// NOTE: qindex_select is functional but not optimized at all.
/// Because of borrowing rules, dequantizing requires a deep clone of the input tensor, which is
/// less than ideal. In the future we would rather index the raw buffer of the quantized tensor
/// directly and dequantize only what is required.
/// TODO: Add support for direct indexing + partial dequantization
///
/// # Errors
///
/// Returns `InvariantError::UnsupportedDType` (converted into `OperationError`) when the source
/// dtype is anything other than `DType::Q8_0F`, and propagates any error from `index_select`.
fn qindex_select(op: IndexSelect, dst: Tensor) -> Result<Tensor, OperationError> {
    // Inspect the dtype on the borrowed source first, so that an unsupported dtype is rejected
    // before paying for a deep clone of the whole tensor.
    // NOTE: Support for other quantization types is dependent on the corresponding
    // dequantization functions.
    let src = match op.src().dt() {
        DType::Q8_0F(_) => {
            let quantizer = Quantizer::new(Quantization::SInt8);
            // The dequantizer takes the tensor by value, hence the deep clone.
            quantizer.sint8_dequantize(op.src().deep_clone())
        }
        dtype => return Err(InvariantError::UnsupportedDType(dtype).into()),
    };
    let indices = op.indices().clone();
    let dim = op.dim();

    index_select::<f32>(IndexSelect::new(src, indices, dim), dst)
}

/// CPU entry point for index-select: picks an implementation based on a tensor dtype.
///
/// NOTE(review): this listing comes from a rendered diff whose +/- markers were lost, so both the
/// pre-change and post-change lines of the function appear below. The comments mark which is
/// which; the post-change lines are the ones that make up the function after this commit.
pub fn cpu_index_select(i: IndexSelect, dst: Tensor) -> Result<Tensor, OperationError> {
// Pre-change line (removed by this commit): dispatch was keyed on the destination dtype.
match dst.dt() {
// Post-change line: dispatch is keyed on the source dtype, which lets a quantized source be
// routed to `qindex_select`.
match i.src().dt() {
// Float dtypes go straight to the generic `index_select` kernel for the matching element type.
DType::F32 => index_select::<f32>(i, dst),
DType::F16 => index_select::<f16>(i, dst),
DType::BF16 => index_select::<bf16>(i, dst),
// Pre-change line (removed by this commit): every other dtype hit `todo!()`, which panics.
_ => todo!(),
// Post-change lines: Q8_0F sources take the dequantize-then-select path, and any remaining
// dtype is reported as an `UnsupportedDType` error instead of panicking.
DType::Q8_0F(_) => qindex_select(i, dst),
dtype => Err(InvariantError::UnsupportedDType(dtype).into()),
}
}

Expand Down
5 changes: 2 additions & 3 deletions crates/ratchet-core/src/ops/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -356,9 +356,8 @@ def index_select(input, indices):
let device = Device::request_device(DeviceRequest::GPU).unwrap();
run_index_select_trial(prob.clone(), device, true);

// TODO: impl CPU index_select for quantized tensors
// let device = Device::request_device(DeviceRequest::CPU).unwrap();
// run_index_select_trial(prob, device, true);
let device = Device::request_device(DeviceRequest::CPU).unwrap();
run_index_select_trial(prob, device, true);
}

#[derive(Debug, Clone)]
Expand Down
2 changes: 1 addition & 1 deletion crates/ratchet-core/src/storage/cpu_buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ impl DeviceStorage for CPUBuffer {
DType::I32 => dump_inner(bytemuck::cast_slice::<u8, i32>(bytes), full),
DType::U32 => dump_inner(bytemuck::cast_slice::<u8, u32>(bytes), full),
DType::F16 => dump_inner(bytemuck::cast_slice::<u8, f16>(bytes), full),
_ => unimplemented!("Unable to dump {:?}", dtype),
dt => format!("[{:?} dump not yet supported]", dt),
}
}
}

0 comments on commit ae33cee

Please sign in to comment.