Skip to content

Commit

Permalink
chore: R1 and R2 match
Browse files Browse the repository at this point in the history
  • Loading branch information
FL33TW00D committed Oct 3, 2024
1 parent 81f4bfc commit 82435eb
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 61 deletions.
2 changes: 1 addition & 1 deletion crates/ratchet-core/src/cpu/gemm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ fn gemm_impl<T: TensorDType>(
}

impl CPUOperation for Matmul {
fn apply(&self, dst: Tensor) -> Result<Tensor, OperationError> {
fn apply_cpu(&self, dst: Tensor) -> Result<Tensor, OperationError> {
fn run_gemm<T: TensorDType>(
spec: MatmulSpec,
lhs: &Tensor,
Expand Down
16 changes: 8 additions & 8 deletions crates/ratchet-core/src/cpu/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ macro_rules! impl_cpu_unary {
impl_cpu_unary_wrapper!($dtype, $conv);

impl CPUOperation for CPU<$dtype, Unary> {
fn apply(&self, dst: Tensor) -> Result<Tensor, OperationError> {
fn apply_cpu(&self, dst: Tensor) -> Result<Tensor, OperationError> {
match self.op.op() {
UnaryOp::Gelu => Self::gelu(self.op.input(), dst),
UnaryOp::Tanh => Self::tanh(self.op.input(), dst),
Expand Down Expand Up @@ -196,9 +196,9 @@ impl_cpu_unary!(bf16, bf16::from_f32);

pub fn cpu_unary(unary: Unary, dst: Tensor) -> Result<Tensor, OperationError> {
match dst.dt() {
DType::F32 => CPU::<f32, _>::new(unary).apply(dst),
DType::F16 => CPU::<f16, _>::new(unary).apply(dst),
DType::BF16 => CPU::<bf16, _>::new(unary).apply(dst),
DType::F32 => CPU::<f32, _>::new(unary).apply_cpu(dst),
DType::F16 => CPU::<f16, _>::new(unary).apply_cpu(dst),
DType::BF16 => CPU::<bf16, _>::new(unary).apply_cpu(dst),
_ => todo!(),
}
}
Expand All @@ -222,7 +222,7 @@ macro_rules! impl_cpu_binary {
}

impl CPUOperation for CPU<$dtype, Binary> {
fn apply(&self, dst: Tensor) -> Result<Tensor, OperationError> {
fn apply_cpu(&self, dst: Tensor) -> Result<Tensor, OperationError> {
match self.op.op() {
BinaryOp::Add => Self::add(self.op.lhs(), self.op.rhs(), dst),
BinaryOp::Sub => Self::sub(self.op.lhs(), self.op.rhs(), dst),
Expand All @@ -240,9 +240,9 @@ impl_cpu_binary!(bf16);

pub fn cpu_binary(binary: Binary, dst: Tensor) -> Result<Tensor, OperationError> {
match dst.dt() {
DType::F32 => CPU::<f32, _>::new(binary).apply(dst),
DType::F16 => CPU::<f16, _>::new(binary).apply(dst),
DType::BF16 => CPU::<bf16, _>::new(binary).apply(dst),
DType::F32 => CPU::<f32, _>::new(binary).apply_cpu(dst),
DType::F16 => CPU::<f16, _>::new(binary).apply_cpu(dst),
DType::BF16 => CPU::<bf16, _>::new(binary).apply_cpu(dst),
_ => todo!(),
}
}
Expand Down
112 changes: 62 additions & 50 deletions crates/ratchet-core/src/cpu/rope.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,30 +105,29 @@ fn merge(data: &[f32], offset: usize, skip: usize) -> Vec<f32> {
}

fn slice(src: &[f32], start: &[usize], stop: &[usize]) -> Vec<f32> {
let stop_numel: usize = stop.iter().product();
let start_numel: usize = stop.iter().product();
assert!(stop_numel >= start_numel);

let mut dst = vec![0.0; stop_numel - start_numel];

/*
start: [0, 0, 0, 8]
stop: [1, 1, 1, 16]
for
*/

let mut src_idx = 0;
let mut dst_idx = 0;
for i in 0..start.len() {
let mut src_stride = start[i];
let mut dst_stride = 0;
while src_stride < stop[i] {
dst[dst_idx] = src[src_idx];
src_idx += src_stride;
dst_idx += dst_stride;
src_stride += 1;
dst_stride += 1;
assert!(start.len() == stop.len());
start.iter().zip(stop.iter()).for_each(|(s, t)| {
assert!(s < t);
});

let src_shape = [2, 16, 16]; // Corrected input shape
let src_strides = [16 * 16, 16, 1];

let delta: Vec<usize> = stop.iter().zip(start.iter()).map(|(s, t)| s - t).collect();
let dst_shape: Vec<usize> = delta.clone();
let dst_numel: usize = delta.iter().product();

let mut dst = vec![0.0; dst_numel];

for i in 0..dst_numel {
let mut src_index = 0;
let mut tmp = i;
for d in 0..delta.len() {
let coord = tmp / dst_shape[d + 1..].iter().product::<usize>().max(1);
tmp %= dst_shape[d + 1..].iter().product::<usize>().max(1);
src_index += (coord + start[d]) * src_strides[d];
}
dst[i] = src[src_index];
}

dst
Expand Down Expand Up @@ -175,48 +174,61 @@ fn transpose(
fn rope(src: Vec<f32>, shape: &Shape, dim: usize, base: f32, offset: usize) -> Vec<f32> {
println!("Ratchet RoPE");
let [batches, num_heads, seq_len, head_dim] = shape.try_into().unwrap();
let el_count = batches * num_heads * seq_len * head_dim;

let half_dim = dim / 2;
let theta = compute_theta(dim, seq_len, base, offset);
println!("Theta: {:?}", theta);
let (sin, cos): (Vec<f32>, Vec<f32>) = theta.iter().map(|i| i.sin_cos()).unzip();
println!("Cos: {:?}", cos);
println!("Sin: {:?}", sin);

let mut intermediate = Vec::with_capacity(el_count);
println!("Cos length: {:?}", cos.len());
println!("Sin length: {:?}", sin.len());

let chunk_offset = half_dim;
let skip = 0;
let x1 = slice(&src, &[0, 0, 0], &[num_heads, seq_len, half_dim]);
let x2 = slice(&src, &[0, 0, half_dim], &[num_heads, seq_len, dim]);
println!("X1: {:?}", x1);
println!("X1 length: {:?}", x1.len());
println!("X2: {:?}", x2);
println!("X2 length: {:?}", x2.len());

let (x1, x2) = chunk_by_offset(&src, chunk_offset, skip);

let (x1_cos, x1_sin): (Vec<f32>, Vec<f32>) = x1
let x1_cos = x1
.iter()
.enumerate()
.map(|(i, x)| (x * cos[i % cos.len()], x * sin[i % sin.len()]))
.unzip();

let (x2_cos, x2_sin): (Vec<f32>, Vec<f32>) = x2
.map(|(i, x)| x * cos[i % cos.len()])
.collect::<Vec<f32>>();
let x2_sin = x2
.iter()
.enumerate()
.map(|(i, x)| (x * cos[i % cos.len()], x * sin[i % sin.len()]))
.unzip();
.map(|(i, x)| x * sin[i % sin.len()])
.collect::<Vec<f32>>();

x1_cos.iter().zip(x2_sin).for_each(|(x1_cos, x2_sin)| {
intermediate.push(x1_cos - x2_sin);
});
let r1 = x1_cos
.iter()
.zip(x2_sin.iter())
.map(|(x1, x2)| x1 - x2)
.collect::<Vec<f32>>();

x1_sin.iter().zip(x2_cos).for_each(|(x1_sin, x2_cos)| {
intermediate.push(x1_sin + x2_cos);
});
let x1_sin = x1
.iter()
.enumerate()
.map(|(i, x)| x * sin[i % sin.len()])
.collect::<Vec<f32>>();
let x2_cos = x2
.iter()
.enumerate()
.map(|(i, x)| x * cos[i % cos.len()])
.collect::<Vec<f32>>();
let r2 = x1_sin
.iter()
.zip(x2_cos.iter())
.map(|(x1, x2)| x1 + x2)
.collect::<Vec<f32>>();

let skip = head_dim.abs_diff(dim);
let mut dst = merge(&intermediate, half_dim, skip);
println!("R1: {:?}", r1);
println!("R2: {:?}", r2);

if dim < head_dim {
let offset = (el_count / head_dim) * dim;
let appendix = &mut src[offset..].to_vec();
dst.append(appendix);
}
dst
vec![]
}

fn rope_2(src: Vec<f32>, shape: &Shape, dim: usize, base: f32, offset: usize) -> Vec<f32> {
Expand Down
3 changes: 3 additions & 0 deletions crates/ratchet-core/src/cpu/slice.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
use crate::{Slice, Tensor};

pub fn cpu_slice(op: Slice, dst: Tensor) -> Result<Tensor, OperationError> {}
2 changes: 1 addition & 1 deletion crates/ratchet-core/src/op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -363,5 +363,5 @@ pub trait GPUOperation: Operation {
}

pub trait CPUOperation: Operation {
fn apply(&self, dst: Tensor) -> Result<Tensor, OperationError>;
fn apply_cpu(&self, dst: Tensor) -> Result<Tensor, OperationError>;
}
2 changes: 1 addition & 1 deletion crates/ratchet-core/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -759,7 +759,7 @@ impl Tensor {
match self.op().clone() {
LazyOp::Binary(b) => cpu_binary(b, dst).ok(),
LazyOp::Cast(c) => cpu_cast(c, dst).ok(),
LazyOp::Matmul(m) => m.apply(dst).ok(),
LazyOp::Matmul(m) => m.apply_cpu(dst).ok(),
LazyOp::Softmax(_s) => todo!(),
LazyOp::RoPE(r) => cpu_rope(r, dst).ok(),
LazyOp::Unary(u) => cpu_unary(u, dst).ok(),
Expand Down

0 comments on commit 82435eb

Please sign in to comment.