From 230e6f78a826074a0f7940339b23c479dcd52f7e Mon Sep 17 00:00:00 2001 From: Jonas Kantic Date: Sun, 5 May 2024 16:16:14 +0200 Subject: [PATCH 1/3] Adds remainder ops implementation for Tensor. --- crates/burn-tensor/src/tensor/api/numeric.rs | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/crates/burn-tensor/src/tensor/api/numeric.rs b/crates/burn-tensor/src/tensor/api/numeric.rs index 40674d6d64..705d3b0072 100644 --- a/crates/burn-tensor/src/tensor/api/numeric.rs +++ b/crates/burn-tensor/src/tensor/api/numeric.rs @@ -2868,6 +2868,20 @@ where } } +impl core::ops::Rem for Tensor +where + E: ElementConversion, + B: Backend, + K: Numeric, + K::Elem: Element, +{ + type Output = Self; + + fn rem(self, other: E) -> Self { + Tensor::remainder_scalar(self, other) + } +} + impl core::ops::Mul> for Tensor where B: Backend, From f7959361b755a905bb54fa2aca54fa5373e7308d Mon Sep 17 00:00:00 2001 From: Jonas Kantic Date: Sun, 5 May 2024 16:56:39 +0200 Subject: [PATCH 2/3] Adds test for % operator. --- crates/burn-tensor/src/tests/ops/remainder.rs | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/crates/burn-tensor/src/tests/ops/remainder.rs b/crates/burn-tensor/src/tests/ops/remainder.rs index e3f86cd0f5..6e69f5889e 100644 --- a/crates/burn-tensor/src/tests/ops/remainder.rs +++ b/crates/burn-tensor/src/tests/ops/remainder.rs @@ -95,4 +95,17 @@ mod tests { let data_expected = Data::from([9.0, 1.0]); data_expected.assert_approx_eq(&data_actual, 3); } + + #[test] + fn should_support_remainder_op() { + let data = Data::from([-3.0, -2.0, -1.0, 1.0, 2.0, 3.0]); + let device = Default::default(); + let tensor = Tensor::::from_data(data, &device); + + let output = tensor % 2.0; + + let data_actual = output.into_data(); + let data_expected = Data::from([1.0, 0.0, 1.0, 1.0, 0.0, 1.0]); + data_expected.assert_approx_eq(&data_actual, 3); + } } From c0e545912ca0c57d9fb8c83e01c40c1fd6ded58a Mon Sep 17 00:00:00 2001 From: Zirconium419122 Date: Wed, 29 May 2024 15:18:42 +0200 Subject: [PATCH 3/3] Add remainder and % operator entry in tensor.md --- burn-book/src/building-blocks/tensor.md | 1 + 1 file changed, 1 insertion(+) diff --git a/burn-book/src/building-blocks/tensor.md b/burn-book/src/building-blocks/tensor.md index e3d10658fc..1ec4290598 100644 --- a/burn-book/src/building-blocks/tensor.md +++ b/burn-book/src/building-blocks/tensor.md @@ -228,6 +228,7 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`. | `tensor.powf_scalar(scalar)` or `tensor.powi_scalar(intscalar)` | `tensor.pow(scalar)` | | `tensor.prod()` | `tensor.prod()` | | `tensor.prod_dim(dim)` | `tensor.prod(dim, keepdim=True)` | +| `tensor.rem(other)` or `tensor % other` | `tensor % other` | | `tensor.scatter(dim, indices, values)` | `tensor.scatter_add(dim, indices, values)` | | `tensor.select(dim, indices)` | `tensor.index_select(dim, indices)` | | `tensor.select_assign(dim, indices, values)` | N/A |