diff --git a/crates/ratchet-core/src/cpu/binary.rs b/crates/ratchet-core/src/cpu/binary.rs index eeb837bb..0e9c0581 100644 --- a/crates/ratchet-core/src/cpu/binary.rs +++ b/crates/ratchet-core/src/cpu/binary.rs @@ -1,49 +1,14 @@ use crate::{ - binary_apply_inplace, Binary, BinaryOp, DType, OpGuards, Operation, OperationError, RVec, - StorageView, Tensor, TensorDType, + binary_apply_inplace, Binary, BinaryOp, CPUOperation, DType, OpGuards, Operation, + OperationError, RVec, StorageView, Tensor, TensorDType, }; use core::marker::PhantomData; use half::{bf16, f16}; -#[derive(Debug)] -pub struct CPU { - op: OP, +pub struct BinaryOps { dtype: PhantomData, } -impl CPU { - pub fn new(op: OP) -> Self { - Self { - op, - dtype: PhantomData, - } - } -} - -impl OpGuards for CPU { - fn check_shapes(&self) { - self.op.check_shapes(); - } - - fn check_dtypes(&self) { - self.op.check_dtypes(); - } -} - -impl Operation for CPU { - fn name(&self) -> &'static str { - self.op.name() - } - - fn compute_view(&self) -> Result { - self.op.compute_view() - } - - fn srcs(&self) -> RVec<&Tensor> { - self.op.srcs() - } -} - macro_rules! impl_cpu_binary_op { ($method_name:ident, $dtype:ident, $op:expr) => { fn $method_name(lhs: &Tensor, rhs: &Tensor, dst: Tensor) -> Result { @@ -55,35 +20,35 @@ macro_rules! impl_cpu_binary_op { macro_rules! impl_cpu_binary { ($dtype:ident) => { - impl CPU<$dtype, Binary> { + impl BinaryOps<$dtype> { impl_cpu_binary_op!(add, $dtype, |lhs, rhs| lhs + rhs); impl_cpu_binary_op!(sub, $dtype, |lhs, rhs| lhs - rhs); impl_cpu_binary_op!(mul, $dtype, |lhs, rhs| lhs * rhs); impl_cpu_binary_op!(div, $dtype, |lhs, rhs| lhs / rhs); - } - impl CPUOperation for CPU<$dtype, Binary> { - fn apply_cpu(&self, dst: Tensor) -> Result { - match self.op.op() { - BinaryOp::Add => Self::add(self.op.lhs(), self.op.rhs(), dst), - BinaryOp::Sub => Self::sub(self.op.lhs(), self.op.rhs(), dst), - BinaryOp::Mul => Self::mul(self.op.lhs(), self.op.rhs(), dst), - BinaryOp::Div => Self::div(self.op.lhs(), self.op.rhs(), dst), + pub fn apply(op: &Binary, dst: Tensor) -> Result { + match op.op() { + BinaryOp::Add => Self::add(op.lhs(), op.rhs(), dst), + BinaryOp::Sub => Self::sub(op.lhs(), op.rhs(), dst), + BinaryOp::Mul => Self::mul(op.lhs(), op.rhs(), dst), + BinaryOp::Div => Self::div(op.lhs(), op.rhs(), dst), } } } }; } +impl CPUOperation for Binary { + fn apply_cpu(&self, dst: Tensor) -> Result { + match dst.dt() { + DType::F32 => BinaryOps::::apply(self, dst), + DType::F16 => BinaryOps::::apply(self, dst), + DType::BF16 => BinaryOps::::apply(self, dst), + _ => todo!(), + } + } +} + impl_cpu_binary!(f32); impl_cpu_binary!(f16); impl_cpu_binary!(bf16); - -pub fn cpu_binary(binary: Binary, dst: Tensor) -> Result { - match dst.dt() { - DType::F32 => CPU::::new(binary).apply_cpu(dst), - DType::F16 => CPU::::new(binary).apply_cpu(dst), - DType::BF16 => CPU::::new(binary).apply_cpu(dst), - _ => todo!(), - } -} diff --git a/crates/ratchet-core/src/cpu/mod.rs b/crates/ratchet-core/src/cpu/mod.rs index 92ff7f34..b6140936 100644 --- a/crates/ratchet-core/src/cpu/mod.rs +++ b/crates/ratchet-core/src/cpu/mod.rs @@ -3,26 +3,16 @@ pub mod gemm; mod unary; use crate::{ - dequantize, Binary, BinaryOp, CPUBuffer, Cast, Concat, DType, IndexSelect, InvariantError, - LazyOp, OpGuards, Operation, OperationError, RVec, Storage, StorageView, Tensor, TensorDType, + dequantize, Binary, CPUBuffer, Cast, Concat, DType, IndexSelect, InvariantError, LazyOp, + Operation, OperationError, RVec, Storage, Tensor, TensorDType, }; use anyhow::anyhow; use bytemuck::NoUninit; -use core::marker::PhantomData; use half::{bf16, f16}; -pub fn cpu_binary(binary: Binary, dst: Tensor) -> Result { - match dst.dt() { - DType::F32 => binary::CPU::::new(binary).apply_cpu(dst), - DType::F16 => binary::CPU::::new(binary).apply_cpu(dst), - DType::BF16 => binary::CPU::::new(binary).apply_cpu(dst), - _ => todo!(), - } -} - pub fn apply_operation(op: LazyOp, dst: Tensor) -> Result { match op { - LazyOp::Binary(b) => cpu_binary(b, dst), + LazyOp::Binary(b) => b.apply_cpu(dst), LazyOp::Cast(c) => cpu_cast(c, dst), LazyOp::Matmul(m) => m.apply_cpu(dst), LazyOp::Softmax(_s) => todo!(),