Skip to content

Commit

Permalink
refactor: yoink CPU<> struct
Browse files Browse the repository at this point in the history
  • Loading branch information
ivarflakstad committed Oct 4, 2024
1 parent 84d6496 commit ca70c05
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 69 deletions.
77 changes: 21 additions & 56 deletions crates/ratchet-core/src/cpu/binary.rs
Original file line number Diff line number Diff line change
@@ -1,49 +1,14 @@
use crate::{
binary_apply_inplace, Binary, BinaryOp, DType, OpGuards, Operation, OperationError, RVec,
StorageView, Tensor, TensorDType,
binary_apply_inplace, Binary, BinaryOp, CPUOperation, DType, OpGuards, Operation,
OperationError, RVec, StorageView, Tensor, TensorDType,
};
use core::marker::PhantomData;
use half::{bf16, f16};

#[derive(Debug)]
pub struct CPU<T: TensorDType, OP: Operation> {
op: OP,
pub struct BinaryOps<T: TensorDType> {
dtype: PhantomData<T>,
}

impl<T: TensorDType, OP: Operation> CPU<T, OP> {
    /// Builds a `CPU` wrapper that takes ownership of `op`.
    ///
    /// The element type `T` exists only at the type level: the `dtype`
    /// field is a zero-sized `PhantomData` marker and stores no value.
    pub fn new(op: OP) -> Self {
        let dtype = PhantomData;
        Self { op, dtype }
    }
}

/// `CPU` adds no guard logic of its own; both checks are forwarded
/// unchanged to the wrapped operation.
impl<T: TensorDType, OP: Operation> OpGuards for CPU<T, OP> {
    /// Runs the wrapped operation's shape check.
    fn check_shapes(&self) {
        let inner = &self.op;
        inner.check_shapes();
    }

    /// Runs the wrapped operation's dtype check.
    fn check_dtypes(&self) {
        let inner = &self.op;
        inner.check_dtypes();
    }
}

/// Every `Operation` method on `CPU` is answered by the wrapped operation;
/// the wrapper contributes nothing but the `T` type parameter.
impl<T: TensorDType, OP: Operation> Operation for CPU<T, OP> {
    /// Returns the wrapped operation's name.
    fn name(&self) -> &'static str {
        let inner = &self.op;
        inner.name()
    }

    /// Returns whatever the wrapped operation's `compute_view` returns,
    /// including its error.
    fn compute_view(&self) -> Result<StorageView, OperationError> {
        let inner = &self.op;
        inner.compute_view()
    }

    /// Returns the wrapped operation's source tensors.
    fn srcs(&self) -> RVec<&Tensor> {
        let inner = &self.op;
        inner.srcs()
    }
}

macro_rules! impl_cpu_binary_op {
($method_name:ident, $dtype:ident, $op:expr) => {
fn $method_name(lhs: &Tensor, rhs: &Tensor, dst: Tensor) -> Result<Tensor, OperationError> {
Expand All @@ -55,35 +20,35 @@ macro_rules! impl_cpu_binary_op {

macro_rules! impl_cpu_binary {
($dtype:ident) => {
impl CPU<$dtype, Binary> {
impl BinaryOps<$dtype> {
impl_cpu_binary_op!(add, $dtype, |lhs, rhs| lhs + rhs);
impl_cpu_binary_op!(sub, $dtype, |lhs, rhs| lhs - rhs);
impl_cpu_binary_op!(mul, $dtype, |lhs, rhs| lhs * rhs);
impl_cpu_binary_op!(div, $dtype, |lhs, rhs| lhs / rhs);
}

impl CPUOperation for CPU<$dtype, Binary> {
fn apply_cpu(&self, dst: Tensor) -> Result<Tensor, OperationError> {
match self.op.op() {
BinaryOp::Add => Self::add(self.op.lhs(), self.op.rhs(), dst),
BinaryOp::Sub => Self::sub(self.op.lhs(), self.op.rhs(), dst),
BinaryOp::Mul => Self::mul(self.op.lhs(), self.op.rhs(), dst),
BinaryOp::Div => Self::div(self.op.lhs(), self.op.rhs(), dst),
pub fn apply(op: &Binary, dst: Tensor) -> Result<Tensor, OperationError> {
match op.op() {
BinaryOp::Add => Self::add(op.lhs(), op.rhs(), dst),
BinaryOp::Sub => Self::sub(op.lhs(), op.rhs(), dst),
BinaryOp::Mul => Self::mul(op.lhs(), op.rhs(), dst),
BinaryOp::Div => Self::div(op.lhs(), op.rhs(), dst),
}
}
}
};
}

impl CPUOperation for Binary {
    /// Selects the `BinaryOps` instantiation that matches the dtype
    /// reported by `dst` and hands this operation and `dst` to it.
    ///
    /// # Panics
    ///
    /// Only `F32`, `F16` and `BF16` are wired up. Any other dtype reaches
    /// `todo!()`, which panics.
    fn apply_cpu(&self, dst: Tensor) -> Result<Tensor, OperationError> {
        let dtype = dst.dt();
        match dtype {
            DType::F32 => BinaryOps::<f32>::apply(self, dst),
            DType::F16 => BinaryOps::<f16>::apply(self, dst),
            DType::BF16 => BinaryOps::<bf16>::apply(self, dst),
            _ => todo!(),
        }
    }
}

impl_cpu_binary!(f32);
impl_cpu_binary!(f16);
impl_cpu_binary!(bf16);

/// Wraps `binary` in the `CPU` kernel matching the dtype reported by `dst`
/// and applies it, returning the result of `apply_cpu`.
///
/// # Panics
///
/// Only `F32`, `F16` and `BF16` are wired up. Any other dtype reaches
/// `todo!()`, which panics.
pub fn cpu_binary(binary: Binary, dst: Tensor) -> Result<Tensor, OperationError> {
    let dtype = dst.dt();
    match dtype {
        DType::F32 => {
            let kernel = CPU::<f32, Binary>::new(binary);
            kernel.apply_cpu(dst)
        }
        DType::F16 => {
            let kernel = CPU::<f16, Binary>::new(binary);
            kernel.apply_cpu(dst)
        }
        DType::BF16 => {
            let kernel = CPU::<bf16, Binary>::new(binary);
            kernel.apply_cpu(dst)
        }
        _ => todo!(),
    }
}
16 changes: 3 additions & 13 deletions crates/ratchet-core/src/cpu/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,16 @@ pub mod gemm;
mod unary;

use crate::{
dequantize, Binary, BinaryOp, CPUBuffer, Cast, Concat, DType, IndexSelect, InvariantError,
LazyOp, OpGuards, Operation, OperationError, RVec, Storage, StorageView, Tensor, TensorDType,
dequantize, Binary, CPUBuffer, Cast, Concat, DType, IndexSelect, InvariantError, LazyOp,
Operation, OperationError, RVec, Storage, Tensor, TensorDType,
};
use anyhow::anyhow;
use bytemuck::NoUninit;
use core::marker::PhantomData;
use half::{bf16, f16};

/// Wraps `binary` in the `binary::CPU` kernel matching the dtype reported
/// by `dst` and applies it, returning the result of `apply_cpu`.
///
/// # Panics
///
/// Only `F32`, `F16` and `BF16` are wired up. Any other dtype reaches
/// `todo!()`, which panics.
pub fn cpu_binary(binary: Binary, dst: Tensor) -> Result<Tensor, OperationError> {
    let dtype = dst.dt();
    match dtype {
        DType::F32 => {
            let kernel = binary::CPU::<f32, Binary>::new(binary);
            kernel.apply_cpu(dst)
        }
        DType::F16 => {
            let kernel = binary::CPU::<f16, Binary>::new(binary);
            kernel.apply_cpu(dst)
        }
        DType::BF16 => {
            let kernel = binary::CPU::<bf16, Binary>::new(binary);
            kernel.apply_cpu(dst)
        }
        _ => todo!(),
    }
}

pub fn apply_operation(op: LazyOp, dst: Tensor) -> Result<Tensor, OperationError> {
match op {
LazyOp::Binary(b) => cpu_binary(b, dst),
LazyOp::Binary(b) => b.apply_cpu(dst),
LazyOp::Cast(c) => cpu_cast(c, dst),
LazyOp::Matmul(m) => m.apply_cpu(dst),
LazyOp::Softmax(_s) => todo!(),
Expand Down

0 comments on commit ca70c05

Please sign in to comment.