Skip to content

Commit

Permalink
chore: improve
Browse files Browse the repository at this point in the history
  • Loading branch information
FL33TW00D committed Jan 22, 2024
1 parent 5022af4 commit ac13243
Showing 1 changed file with 114 additions and 12 deletions.
126 changes: 114 additions & 12 deletions crates/ratchet-core/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,105 @@ impl Tensor {
}
}

pub fn deep_clone(&self) -> Tensor {
let storage_guard = self.storage();
let storage = storage_guard.as_ref().unwrap();
let cloned_storage = storage.deep_clone().unwrap();
Tensor::new(
LazyOp::Const,
self.view.clone(),
Some(cloned_storage),
self.device.clone(),
)
}
}

impl Tensor {
pub fn all_close(&self, other: &Self, atol: f32, rtol: f32) -> anyhow::Result<()> {
if self.shape() != other.shape() {
anyhow::bail!("Shape mismatch {:?} != {:?}", self.shape(), other.shape())
}

let self_nd = self.to_ndarray_view::<f32>();
let other_nd = other.to_ndarray_view::<f32>();
let mut stats = CloseStats::new(atol, rtol);

ndarray::indices_of(&self_nd).into_iter().for_each(|idx| {
let (a, b) = (self_nd[&idx], other_nd[&idx]);
stats.update(&a, &b, idx);
});

if stats.fail_count > 0 {
anyhow::bail!(
"{} samples not close - AVGE={} MAE={} at {:?}",
stats.fail_count,
stats.avg_error(),
stats.max_abs_error,
stats.max_abs_error_idxs,
);
} else {
println!(
"All close - AVGE={} MAE={} at {:?}",
stats.avg_error(),
stats.max_abs_error,
stats.max_abs_error_idxs
);
Ok(())
}
}
}

struct CloseStats {
total_error: f32,
max_abs_error: f32,
max_abs_error_idxs: Option<ndarray::IxDyn>,
element_count: usize,
fail_count: usize,
atol: f32,
rtol: f32,
}

impl CloseStats {
fn new(atol: f32, rtol: f32) -> Self {
Self {
total_error: 0.0,
max_abs_error: 0.0,
max_abs_error_idxs: None,
element_count: 0,
fail_count: 0,
atol,
rtol,
}
}

fn update(&mut self, a: &f32, b: &f32, index: ndarray::IxDyn) {
let abs_diff = (a - b).abs();
self.total_error += abs_diff;
self.element_count += 1;

if abs_diff > self.max_abs_error {
self.max_abs_error = abs_diff;
self.max_abs_error_idxs = Some(index);
}

if !self.is_close(a, b, abs_diff) {
self.fail_count += 1;
}
}

fn avg_error(&self) -> f32 {
self.total_error / self.element_count as f32
}

fn is_close(&self, a: &f32, b: &f32, abs_diff: f32) -> bool {
(a.is_nan() && b.is_nan())
|| (a.is_infinite() && b.is_infinite() && a.signum() == b.signum())
|| abs_diff <= self.atol + self.rtol * b.abs()
}
}

/// Conversion to and from numpy arrays
impl Tensor {
#[cfg(feature = "pyo3")]
pub fn into_ndarray<T: TensorDType>(self) -> ArrayD<T> {
assert!(self.device().is_cpu());
Expand All @@ -346,6 +445,20 @@ impl Tensor {
}
}

#[cfg(feature = "pyo3")]
pub fn to_ndarray_view<T: TensorDType>(&self) -> ArrayViewD<T> {
assert!(self.device().is_cpu());
let shape = self.shape().to_vec();
if self.num_bytes() != 0 {
let storage_guard = self.storage();
let buffer = storage_guard.as_ref().unwrap().try_cpu().unwrap();
let (ptr, _) = buffer.inner().into_raw_parts();
unsafe { ArrayViewD::from_shape_ptr(shape, ptr as *const T) }
} else {
ArrayViewD::from_shape(shape, &[]).unwrap()
}
}

#[cfg(feature = "pyo3")]
pub fn to_py<'s, 'p: 's, T: TensorDType + numpy::Element>(
&'s self,
Expand All @@ -358,18 +471,6 @@ impl Tensor {
);
PyArray::from_owned_array(*py, self.deep_clone().into_ndarray::<T>())
}

pub fn deep_clone(&self) -> Tensor {
let storage_guard = self.storage();
let storage = storage_guard.as_ref().unwrap();
let cloned_storage = storage.deep_clone().unwrap();
Tensor::new(
LazyOp::Const,
self.view.clone(),
Some(cloned_storage),
self.device.clone(),
)
}
}

#[cfg(feature = "pyo3")]
Expand Down Expand Up @@ -444,6 +545,7 @@ def matmul(a, b):
let c_gpu = a_gpu.matmul(&b_gpu)?;
c_gpu.resolve()?;
let d_gpu = c_gpu.to(Device::CPU)?;
ground?.all_close(&d_gpu, 1e-5, 1e-5)?;
Ok(())
}
}

0 comments on commit ac13243

Please sign in to comment.