Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add attribute test macro #1470

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ members = [
"candle-book",
"candle-nn",
"candle-pyo3",
"candle-test-macro",
"candle-transformers",
"candle-wasm-examples/*",
"candle-wasm-tests",
Expand Down
3 changes: 2 additions & 1 deletion candle-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ accelerate-src = { workspace = true, optional = true }
byteorder = { workspace = true }
candle-kernels = { path = "../candle-kernels", version = "0.3.2", optional = true }
candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.2", optional = true }
candle-test-macro = { path = "../candle-test-macro", version = "0.3.2" }
metal = { workspace = true, optional = true}
cudarc = { workspace = true, optional = true }
gemm = { workspace = true }
Expand All @@ -38,7 +39,7 @@ criterion = { workspace = true }


[features]
default = []
default = ["cuda"]
cuda = ["cudarc", "dep:candle-kernels"]
cudnn = ["cuda", "cudarc/cudnn"]
mkl = ["dep:libc", "dep:intel-mkl-src"]
Expand Down
116 changes: 60 additions & 56 deletions candle-core/tests/tensor_tests.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use candle_core::{test_device, test_utils, DType, Device, IndexOp, Result, Tensor, D};
use candle_core::{test_utils, DType, Device, IndexOp, Result, Tensor, D};
use candle_test_macro::test_device;

#[test_device]
fn zeros(device: &Device) -> Result<()> {
let tensor = Tensor::zeros((5, 2), DType::F32, device)?;
let (dim1, dim2) = tensor.dims2()?;
Expand All @@ -8,6 +10,7 @@ fn zeros(device: &Device) -> Result<()> {
Ok(())
}

#[test_device]
fn ones(device: &Device) -> Result<()> {
assert_eq!(
Tensor::ones((2, 3), DType::U8, device)?.to_vec2::<u8>()?,
Expand All @@ -32,6 +35,8 @@ fn ones(device: &Device) -> Result<()> {
Ok(())
}


#[test_device]
fn full(device: &Device) -> Result<()> {
assert_eq!(
Tensor::full(42u32, (2, 3), device)?.to_vec2::<u32>()?,
Expand All @@ -40,6 +45,8 @@ fn full(device: &Device) -> Result<()> {
Ok(())
}


#[test_device]
fn arange(device: &Device) -> Result<()> {
assert_eq!(
Tensor::arange(0u8, 5u8, device)?.to_vec1::<u8>()?,
Expand All @@ -60,6 +67,8 @@ fn arange(device: &Device) -> Result<()> {
Ok(())
}


#[test_device]
fn add_mul(device: &Device) -> Result<()> {
let tensor = Tensor::new(&[3f32, 1., 4.], device)?;
let dim1 = tensor.dims1()?;
Expand All @@ -75,6 +84,8 @@ fn add_mul(device: &Device) -> Result<()> {
Ok(())
}


#[test_device]
fn tensor_2d(device: &Device) -> Result<()> {
let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
let tensor = Tensor::new(data, device)?;
Expand All @@ -85,6 +96,8 @@ fn tensor_2d(device: &Device) -> Result<()> {
Ok(())
}


#[test_device]
fn clamp(device: &Device) -> Result<()> {
let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
let tensor = Tensor::new(data, device)?;
Expand All @@ -96,6 +109,8 @@ fn clamp(device: &Device) -> Result<()> {
Ok(())
}


#[test_device]
fn unary_op(device: &Device) -> Result<()> {
let data = &[[-3f32, 1., 4., -0.1, 0.5], [2.7, -1.8, -0.28, 1.8, 2.8]];
let tensor = Tensor::new(data, device)?;
Expand Down Expand Up @@ -144,6 +159,8 @@ fn unary_op(device: &Device) -> Result<()> {
Ok(())
}


#[test_device]
fn binary_op(device: &Device) -> Result<()> {
let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
let tensor1 = Tensor::new(data, device)?;
Expand Down Expand Up @@ -173,6 +190,8 @@ fn binary_op(device: &Device) -> Result<()> {
Ok(())
}


#[test_device]
fn transpose(device: &Device) -> Result<()> {
let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
let tensor = Tensor::new(data, device)?.t()?;
Expand All @@ -188,6 +207,8 @@ fn transpose(device: &Device) -> Result<()> {
Ok(())
}


#[test_device]
fn var(device: &Device) -> Result<()> {
// Values taken from https://pytorch.org/docs/stable/generated/torch.var.html
let data = &[
Expand All @@ -204,6 +225,8 @@ fn var(device: &Device) -> Result<()> {
Ok(())
}


#[test_device]
fn sum(device: &Device) -> Result<()> {
let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
let tensor = Tensor::new(data, device)?;
Expand Down Expand Up @@ -301,6 +324,8 @@ fn sum(device: &Device) -> Result<()> {
Ok(())
}


#[test_device]
fn min(device: &Device) -> Result<()> {
let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
let tensor = Tensor::new(data, device)?;
Expand Down Expand Up @@ -363,6 +388,8 @@ fn min(device: &Device) -> Result<()> {
Ok(())
}


#[test_device]
fn max(device: &Device) -> Result<()> {
let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
let tensor = Tensor::new(data, device)?;
Expand Down Expand Up @@ -425,6 +452,8 @@ fn max(device: &Device) -> Result<()> {
Ok(())
}


#[test_device]
fn argmin(device: &Device) -> Result<()> {
let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
let tensor = Tensor::new(data, device)?;
Expand Down Expand Up @@ -499,6 +528,8 @@ fn argmin(device: &Device) -> Result<()> {
Ok(())
}


#[test_device]
fn argmax(device: &Device) -> Result<()> {
let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
let tensor = Tensor::new(data, device)?;
Expand Down Expand Up @@ -573,6 +604,8 @@ fn argmax(device: &Device) -> Result<()> {
Ok(())
}


#[test_device]
fn narrow(device: &Device) -> Result<()> {
let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
let tensor = Tensor::new(data, device)?;
Expand Down Expand Up @@ -603,6 +636,8 @@ fn narrow(device: &Device) -> Result<()> {
Ok(())
}


#[test_device]
fn broadcast(device: &Device) -> Result<()> {
let data = &[3f32, 1., 4.];
let tensor = Tensor::new(data, device)?;
Expand All @@ -613,6 +648,8 @@ fn broadcast(device: &Device) -> Result<()> {
Ok(())
}


#[test_device]
fn cat(device: &Device) -> Result<()> {
// 1D
let t1 = Tensor::new(&[3f32, 1., 4.], device)?;
Expand Down Expand Up @@ -668,6 +705,8 @@ fn cat(device: &Device) -> Result<()> {
Ok(())
}


#[test_device]
fn embeddings(device: &Device) -> Result<()> {
let ids = Tensor::new(&[0u32, 2u32, 1u32], device)?;
let t = Tensor::new(&[[0f32, 1f32], [2f32, 3f32], [4f32, 5f32]], device)?;
Expand All @@ -678,6 +717,8 @@ fn embeddings(device: &Device) -> Result<()> {
Ok(())
}


#[test_device]
fn cmp(device: &Device) -> Result<()> {
let t1 = Tensor::new(&[[0f32, 1f32], [2f32, 3f32], [4f32, 5f32]], device)?;
let t2 = Tensor::new(&[[1f32, 0f32], [3f32, 3f32], [4f32, 7f32]], device)?;
Expand All @@ -690,6 +731,8 @@ fn cmp(device: &Device) -> Result<()> {
Ok(())
}


#[test_device]
fn index_select(device: &Device) -> Result<()> {
let ids = Tensor::new(&[0u32, 2u32, 1u32], device)?;
let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
Expand Down Expand Up @@ -744,6 +787,8 @@ fn index_select(device: &Device) -> Result<()> {
Ok(())
}


#[test_device]
fn index_add(device: &Device) -> Result<()> {
let ids = Tensor::new(&[0u32, 1u32, 1u32], device)?;
let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
Expand Down Expand Up @@ -787,6 +832,8 @@ fn index_add(device: &Device) -> Result<()> {
Ok(())
}


#[test_device]
fn slice_scatter(device: &Device) -> Result<()> {
let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
assert_eq!(
Expand Down Expand Up @@ -829,6 +876,8 @@ fn slice_scatter(device: &Device) -> Result<()> {
Ok(())
}


#[test_device]
fn scatter_add(device: &Device) -> Result<()> {
let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
assert_eq!(
Expand Down Expand Up @@ -869,6 +918,8 @@ fn scatter_add(device: &Device) -> Result<()> {
Ok(())
}


#[test_device]
fn gather(device: &Device) -> Result<()> {
let ids = Tensor::new(&[[0u32], [2u32], [1u32], [0u32]], device)?;
let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
Expand Down Expand Up @@ -901,6 +952,8 @@ fn gather(device: &Device) -> Result<()> {
Ok(())
}


#[test_device]
fn matmul(device: &Device) -> Result<()> {
let data = vec![1.0f32, 2.0, 3.0, 4.0];
let a = Tensor::from_slice(&data, (2, 2), device)?;
Expand Down Expand Up @@ -950,6 +1003,8 @@ fn matmul(device: &Device) -> Result<()> {
Ok(())
}


#[test_device]
fn broadcast_matmul(device: &Device) -> Result<()> {
let lhs = Tensor::randn(0f32, 1f32, (3, 1, 4, 5), device)?;
let rhs = Tensor::randn(0f32, 1f32, (6, 5, 2), device)?;
Expand All @@ -969,6 +1024,8 @@ fn broadcast_matmul(device: &Device) -> Result<()> {
Ok(())
}


#[test_device]
fn broadcasting(device: &Device) -> Result<()> {
let t1 = Tensor::arange(0f32, 24f32, device)?.reshape((4, 2, 3))?;
let t2 = Tensor::new(&[100f32, 200f32], device)?;
Expand Down Expand Up @@ -1070,6 +1127,8 @@ fn broadcasting(device: &Device) -> Result<()> {
Ok(())
}


#[test_device]
fn randn(device: &Device) -> Result<()> {
let tensor = Tensor::randn(0f32, 1f32, (5, 3), device)?;
assert_eq!(tensor.dims(), [5, 3]);
Expand All @@ -1078,61 +1137,6 @@ fn randn(device: &Device) -> Result<()> {
Ok(())
}

test_device!(zeros, zeros_cpu, zeros_gpu, zeros_metal);
test_device!(ones, ones_cpu, ones_gpu, ones_metal);
test_device!(full, full_cpu, full_gpu, full_metal);
test_device!(arange, arange_cpu, arange_gpu, arange_metal);
test_device!(add_mul, add_mul_cpu, add_mul_gpu, add_mul_metal);
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu, tensor_2d_metal);
test_device!(narrow, narrow_cpu, narrow_gpu, narrow_metal);
test_device!(broadcast, broadcast_cpu, broadcast_gpu, broadcast_metal);
test_device!(cat, cat_cpu, cat_gpu, cat_metal);
test_device!(sum, sum_cpu, sum_gpu, sum_metal);
test_device!(min, min_cpu, min_gpu, min_metal);
test_device!(max, max_cpu, max_gpu, max_metal);
test_device!(argmax, argmax_cpu, argmax_gpu, argmax_metal);
test_device!(argmin, argmin_cpu, argmin_gpu, argmin_metal);
test_device!(transpose, transpose_cpu, transpose_gpu, transpose_metal);
test_device!(unary_op, unary_op_cpu, unary_op_gpu, unary_op_metal);
test_device!(binary_op, binary_op_cpu, binary_op_gpu, binary_op_metal);
test_device!(embeddings, embeddings_cpu, embeddings_gpu, embeddings_metal);
test_device!(cmp, cmp_cpu, cmp_gpu, cmp_metal);
test_device!(matmul, matmul_cpu, matmul_gpu, matmul_metal);
test_device!(
broadcast_matmul,
broadcast_matmul_cpu,
broadcast_matmul_gpu,
broadcast_matmul_metal
);
test_device!(
broadcasting,
broadcasting_cpu,
broadcasting_gpu,
broadcasting_metal
);
test_device!(
index_select,
index_select_cpu,
index_select_gpu,
index_select_metal
);
test_device!(index_add, index_add_cpu, index_add_gpu, index_add_metal);
test_device!(gather, gather_cpu, gather_gpu, gather_metal);
test_device!(
scatter_add,
scatter_add_cpu,
scatter_add_gpu,
scatter_add_metal
);
test_device!(
slice_scatter,
slice_scatter_cpu,
slice_scatter_gpu,
slice_scatter_metal
);
test_device!(randn, randn_cpu, randn_gpu, randn_metal);
test_device!(clamp, clamp_cpu, clamp_gpu, clamp_metal);
test_device!(var, var_cpu, var_gpu, var_metal);

// There was originally a bug on the CPU implementation for randn
// https://github.com/huggingface/candle/issues/381
Expand Down
26 changes: 26 additions & 0 deletions candle-test-macro/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
[package]
name = "candle-test-macro"
version.workspace = true
edition.workspace = true
description.workspace = true
repository.workspace = true
keywords.workspace = true
categories.workspace = true
license.workspace = true

[lib]
proc-macro = true

[features]

default = []

cuda = ["candle/cuda"]

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
syn={ version = "2.0", features = ["full"]}

[dev-dependencies]
candle = { path = "../candle-core", version = "0.3.2", package = "candle-core" }
Loading
Loading