Skip to content

Commit

Permalink
Add attribute test macro
Browse files Browse the repository at this point in the history
  • Loading branch information
mokeyish committed Dec 23, 2023
1 parent ceb78d3 commit d436831
Show file tree
Hide file tree
Showing 6 changed files with 142 additions and 57 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ members = [
"candle-book",
"candle-nn",
"candle-pyo3",
"candle-test-macro",
"candle-transformers",
"candle-wasm-examples/*",
"candle-wasm-tests",
Expand Down
3 changes: 2 additions & 1 deletion candle-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ accelerate-src = { workspace = true, optional = true }
byteorder = { workspace = true }
candle-kernels = { path = "../candle-kernels", version = "0.3.2", optional = true }
candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.2", optional = true }
candle-test-macro = { path = "../candle-test-macro", version = "0.3.2" }
metal = { workspace = true, optional = true}
cudarc = { workspace = true, optional = true }
gemm = { workspace = true }
Expand All @@ -38,7 +39,7 @@ criterion = { workspace = true }


[features]
default = []
default = ["cuda"]
cuda = ["cudarc", "dep:candle-kernels"]
cudnn = ["cuda", "cudarc/cudnn"]
mkl = ["dep:libc", "dep:intel-mkl-src"]
Expand Down
116 changes: 60 additions & 56 deletions candle-core/tests/tensor_tests.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use candle_core::{test_device, test_utils, DType, Device, IndexOp, Result, Tensor, D};
use candle_core::{test_utils, DType, Device, IndexOp, Result, Tensor, D};
use candle_test_macro::test_device;

#[test_device]
fn zeros(device: &Device) -> Result<()> {
let tensor = Tensor::zeros((5, 2), DType::F32, device)?;
let (dim1, dim2) = tensor.dims2()?;
Expand All @@ -8,6 +10,7 @@ fn zeros(device: &Device) -> Result<()> {
Ok(())
}

#[test_device]
fn ones(device: &Device) -> Result<()> {
assert_eq!(
Tensor::ones((2, 3), DType::U8, device)?.to_vec2::<u8>()?,
Expand All @@ -32,6 +35,8 @@ fn ones(device: &Device) -> Result<()> {
Ok(())
}


#[test_device]
fn full(device: &Device) -> Result<()> {
assert_eq!(
Tensor::full(42u32, (2, 3), device)?.to_vec2::<u32>()?,
Expand All @@ -40,6 +45,8 @@ fn full(device: &Device) -> Result<()> {
Ok(())
}


#[test_device]
fn arange(device: &Device) -> Result<()> {
assert_eq!(
Tensor::arange(0u8, 5u8, device)?.to_vec1::<u8>()?,
Expand All @@ -60,6 +67,8 @@ fn arange(device: &Device) -> Result<()> {
Ok(())
}


#[test_device]
fn add_mul(device: &Device) -> Result<()> {
let tensor = Tensor::new(&[3f32, 1., 4.], device)?;
let dim1 = tensor.dims1()?;
Expand All @@ -75,6 +84,8 @@ fn add_mul(device: &Device) -> Result<()> {
Ok(())
}


#[test_device]
fn tensor_2d(device: &Device) -> Result<()> {
let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
let tensor = Tensor::new(data, device)?;
Expand All @@ -85,6 +96,8 @@ fn tensor_2d(device: &Device) -> Result<()> {
Ok(())
}


#[test_device]
fn clamp(device: &Device) -> Result<()> {
let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
let tensor = Tensor::new(data, device)?;
Expand All @@ -96,6 +109,8 @@ fn clamp(device: &Device) -> Result<()> {
Ok(())
}


#[test_device]
fn unary_op(device: &Device) -> Result<()> {
let data = &[[-3f32, 1., 4., -0.1, 0.5], [2.7, -1.8, -0.28, 1.8, 2.8]];
let tensor = Tensor::new(data, device)?;
Expand Down Expand Up @@ -144,6 +159,8 @@ fn unary_op(device: &Device) -> Result<()> {
Ok(())
}


#[test_device]
fn binary_op(device: &Device) -> Result<()> {
let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
let tensor1 = Tensor::new(data, device)?;
Expand Down Expand Up @@ -173,6 +190,8 @@ fn binary_op(device: &Device) -> Result<()> {
Ok(())
}


#[test_device]
fn transpose(device: &Device) -> Result<()> {
let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
let tensor = Tensor::new(data, device)?.t()?;
Expand All @@ -188,6 +207,8 @@ fn transpose(device: &Device) -> Result<()> {
Ok(())
}


#[test_device]
fn var(device: &Device) -> Result<()> {
// Values taken from https://pytorch.org/docs/stable/generated/torch.var.html
let data = &[
Expand All @@ -204,6 +225,8 @@ fn var(device: &Device) -> Result<()> {
Ok(())
}


#[test_device]
fn sum(device: &Device) -> Result<()> {
let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
let tensor = Tensor::new(data, device)?;
Expand Down Expand Up @@ -301,6 +324,8 @@ fn sum(device: &Device) -> Result<()> {
Ok(())
}


#[test_device]
fn min(device: &Device) -> Result<()> {
let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
let tensor = Tensor::new(data, device)?;
Expand Down Expand Up @@ -363,6 +388,8 @@ fn min(device: &Device) -> Result<()> {
Ok(())
}


#[test_device]
fn max(device: &Device) -> Result<()> {
let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
let tensor = Tensor::new(data, device)?;
Expand Down Expand Up @@ -425,6 +452,8 @@ fn max(device: &Device) -> Result<()> {
Ok(())
}


#[test_device]
fn argmin(device: &Device) -> Result<()> {
let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
let tensor = Tensor::new(data, device)?;
Expand Down Expand Up @@ -499,6 +528,8 @@ fn argmin(device: &Device) -> Result<()> {
Ok(())
}


#[test_device]
fn argmax(device: &Device) -> Result<()> {
let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
let tensor = Tensor::new(data, device)?;
Expand Down Expand Up @@ -573,6 +604,8 @@ fn argmax(device: &Device) -> Result<()> {
Ok(())
}


#[test_device]
fn narrow(device: &Device) -> Result<()> {
let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
let tensor = Tensor::new(data, device)?;
Expand Down Expand Up @@ -603,6 +636,8 @@ fn narrow(device: &Device) -> Result<()> {
Ok(())
}


#[test_device]
fn broadcast(device: &Device) -> Result<()> {
let data = &[3f32, 1., 4.];
let tensor = Tensor::new(data, device)?;
Expand All @@ -613,6 +648,8 @@ fn broadcast(device: &Device) -> Result<()> {
Ok(())
}


#[test_device]
fn cat(device: &Device) -> Result<()> {
// 1D
let t1 = Tensor::new(&[3f32, 1., 4.], device)?;
Expand Down Expand Up @@ -668,6 +705,8 @@ fn cat(device: &Device) -> Result<()> {
Ok(())
}


#[test_device]
fn embeddings(device: &Device) -> Result<()> {
let ids = Tensor::new(&[0u32, 2u32, 1u32], device)?;
let t = Tensor::new(&[[0f32, 1f32], [2f32, 3f32], [4f32, 5f32]], device)?;
Expand All @@ -678,6 +717,8 @@ fn embeddings(device: &Device) -> Result<()> {
Ok(())
}


#[test_device]
fn cmp(device: &Device) -> Result<()> {
let t1 = Tensor::new(&[[0f32, 1f32], [2f32, 3f32], [4f32, 5f32]], device)?;
let t2 = Tensor::new(&[[1f32, 0f32], [3f32, 3f32], [4f32, 7f32]], device)?;
Expand All @@ -690,6 +731,8 @@ fn cmp(device: &Device) -> Result<()> {
Ok(())
}


#[test_device]
fn index_select(device: &Device) -> Result<()> {
let ids = Tensor::new(&[0u32, 2u32, 1u32], device)?;
let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
Expand Down Expand Up @@ -744,6 +787,8 @@ fn index_select(device: &Device) -> Result<()> {
Ok(())
}


#[test_device]
fn index_add(device: &Device) -> Result<()> {
let ids = Tensor::new(&[0u32, 1u32, 1u32], device)?;
let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
Expand Down Expand Up @@ -787,6 +832,8 @@ fn index_add(device: &Device) -> Result<()> {
Ok(())
}


#[test_device]
fn slice_scatter(device: &Device) -> Result<()> {
let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
assert_eq!(
Expand Down Expand Up @@ -829,6 +876,8 @@ fn slice_scatter(device: &Device) -> Result<()> {
Ok(())
}


#[test_device]
fn scatter_add(device: &Device) -> Result<()> {
let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
assert_eq!(
Expand Down Expand Up @@ -869,6 +918,8 @@ fn scatter_add(device: &Device) -> Result<()> {
Ok(())
}


#[test_device]
fn gather(device: &Device) -> Result<()> {
let ids = Tensor::new(&[[0u32], [2u32], [1u32], [0u32]], device)?;
let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
Expand Down Expand Up @@ -901,6 +952,8 @@ fn gather(device: &Device) -> Result<()> {
Ok(())
}


#[test_device]
fn matmul(device: &Device) -> Result<()> {
let data = vec![1.0f32, 2.0, 3.0, 4.0];
let a = Tensor::from_slice(&data, (2, 2), device)?;
Expand Down Expand Up @@ -950,6 +1003,8 @@ fn matmul(device: &Device) -> Result<()> {
Ok(())
}


#[test_device]
fn broadcast_matmul(device: &Device) -> Result<()> {
let lhs = Tensor::randn(0f32, 1f32, (3, 1, 4, 5), device)?;
let rhs = Tensor::randn(0f32, 1f32, (6, 5, 2), device)?;
Expand All @@ -969,6 +1024,8 @@ fn broadcast_matmul(device: &Device) -> Result<()> {
Ok(())
}


#[test_device]
fn broadcasting(device: &Device) -> Result<()> {
let t1 = Tensor::arange(0f32, 24f32, device)?.reshape((4, 2, 3))?;
let t2 = Tensor::new(&[100f32, 200f32], device)?;
Expand Down Expand Up @@ -1070,6 +1127,8 @@ fn broadcasting(device: &Device) -> Result<()> {
Ok(())
}


#[test_device]
fn randn(device: &Device) -> Result<()> {
let tensor = Tensor::randn(0f32, 1f32, (5, 3), device)?;
assert_eq!(tensor.dims(), [5, 3]);
Expand All @@ -1078,61 +1137,6 @@ fn randn(device: &Device) -> Result<()> {
Ok(())
}

test_device!(zeros, zeros_cpu, zeros_gpu, zeros_metal);
test_device!(ones, ones_cpu, ones_gpu, ones_metal);
test_device!(full, full_cpu, full_gpu, full_metal);
test_device!(arange, arange_cpu, arange_gpu, arange_metal);
test_device!(add_mul, add_mul_cpu, add_mul_gpu, add_mul_metal);
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu, tensor_2d_metal);
test_device!(narrow, narrow_cpu, narrow_gpu, narrow_metal);
test_device!(broadcast, broadcast_cpu, broadcast_gpu, broadcast_metal);
test_device!(cat, cat_cpu, cat_gpu, cat_metal);
test_device!(sum, sum_cpu, sum_gpu, sum_metal);
test_device!(min, min_cpu, min_gpu, min_metal);
test_device!(max, max_cpu, max_gpu, max_metal);
test_device!(argmax, argmax_cpu, argmax_gpu, argmax_metal);
test_device!(argmin, argmin_cpu, argmin_gpu, argmin_metal);
test_device!(transpose, transpose_cpu, transpose_gpu, transpose_metal);
test_device!(unary_op, unary_op_cpu, unary_op_gpu, unary_op_metal);
test_device!(binary_op, binary_op_cpu, binary_op_gpu, binary_op_metal);
test_device!(embeddings, embeddings_cpu, embeddings_gpu, embeddings_metal);
test_device!(cmp, cmp_cpu, cmp_gpu, cmp_metal);
test_device!(matmul, matmul_cpu, matmul_gpu, matmul_metal);
test_device!(
broadcast_matmul,
broadcast_matmul_cpu,
broadcast_matmul_gpu,
broadcast_matmul_metal
);
test_device!(
broadcasting,
broadcasting_cpu,
broadcasting_gpu,
broadcasting_metal
);
test_device!(
index_select,
index_select_cpu,
index_select_gpu,
index_select_metal
);
test_device!(index_add, index_add_cpu, index_add_gpu, index_add_metal);
test_device!(gather, gather_cpu, gather_gpu, gather_metal);
test_device!(
scatter_add,
scatter_add_cpu,
scatter_add_gpu,
scatter_add_metal
);
test_device!(
slice_scatter,
slice_scatter_cpu,
slice_scatter_gpu,
slice_scatter_metal
);
test_device!(randn, randn_cpu, randn_gpu, randn_metal);
test_device!(clamp, clamp_cpu, clamp_gpu, clamp_metal);
test_device!(var, var_cpu, var_gpu, var_metal);

// There was originally a bug on the CPU implementation for randn
// https://github.com/huggingface/candle/issues/381
Expand Down
26 changes: 26 additions & 0 deletions candle-test-macro/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
[package]
name = "candle-test-macro"
version.workspace = true
edition.workspace = true
description.workspace = true
repository.workspace = true
keywords.workspace = true
categories.workspace = true
license.workspace = true

[lib]
proc-macro = true

[features]

default = []

cuda = ["candle/cuda"]

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
syn={ version = "2.0", features = ["full"]}

[dev-dependencies]
candle = { path = "../candle-core", version = "0.3.2", package = "candle-core" }
Loading

0 comments on commit d436831

Please sign in to comment.