Skip to content

Commit

Permalink
add axis_windows_with_stride() method
Browse files Browse the repository at this point in the history
  • Loading branch information
goertzenator committed Dec 13, 2024
1 parent 4e61c87 commit c0961a6
Show file tree
Hide file tree
Showing 3 changed files with 166 additions and 5 deletions.
20 changes: 19 additions & 1 deletion src/impl_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1493,6 +1493,19 @@ where
/// ```
pub fn axis_windows(&self, axis: Axis, window_size: usize) -> AxisWindows<'_, A, D>
where S: Data
{
self.axis_windows_with_stride(axis, window_size, 1)
}

/// Returns a producer which traverses over windows of a given length and
/// stride along an axis.
///
/// Note that a calling this method with a stride of 1 is equivalent to
/// calling [`ArrayBase::axis_windows()`].
pub fn axis_windows_with_stride(
&self, axis: Axis, window_size: usize, stride_size: usize,
) -> AxisWindows<'_, A, D>
where S: Data
{
let axis_index = axis.index();

Expand All @@ -1507,7 +1520,12 @@ where
self.shape()
);

AxisWindows::new(self.view(), axis, window_size)
ndassert!(
stride_size >0,
"Stride size must be greater than zero"
);

AxisWindows::new_with_stride(self.view(), axis, window_size, stride_size)
}

// Return (length, stride) for diagonal
Expand Down
9 changes: 5 additions & 4 deletions src/iterators/windows.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ pub struct AxisWindows<'a, A, D>

impl<'a, A, D: Dimension> AxisWindows<'a, A, D>
{
pub(crate) fn new(a: ArrayView<'a, A, D>, axis: Axis, window_size: usize) -> Self
pub(crate) fn new_with_stride(a: ArrayView<'a, A, D>, axis: Axis, window_size: usize, stride_size: usize) -> Self
{
let window_strides = a.strides.clone();
let axis_idx = axis.index();
Expand All @@ -150,10 +150,11 @@ impl<'a, A, D: Dimension> AxisWindows<'a, A, D>
window[axis_idx] = window_size;

let ndim = window.ndim();
let mut unit_stride = D::zeros(ndim);
unit_stride.slice_mut().fill(1);
let mut stride = D::zeros(ndim);
stride.slice_mut().fill(1);
stride[axis_idx] = stride_size;

let base = build_base(a, window.clone(), unit_stride);
let base = build_base(a, window.clone(), stride);
AxisWindows {
base,
axis_idx,
Expand Down
142 changes: 142 additions & 0 deletions tests/windows.rs
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,148 @@ fn tests_axis_windows_3d_zips_with_1d()
assert_eq!(b,arr1(&[207, 261]));
}

/// Test verifies that non existent Axis results in panic
#[test]
#[should_panic]
fn axis_windows_with_stride_outofbound()
{
let a = Array::from_iter(10..37)
.into_shape_with_order((3, 3, 3))
.unwrap();
a.axis_windows_with_stride(Axis(4), 2, 2);
}

/// Test verifies that zero sizes results in panic
#[test]
#[should_panic]
fn axis_windows_with_stride_zero_size()
{
let a = Array::from_iter(10..37)
.into_shape_with_order((3, 3, 3))
.unwrap();
a.axis_windows_with_stride(Axis(0), 0, 2);
}

/// Test verifies that zero stride results in panic
#[test]
#[should_panic]
fn axis_windows_with_stride_zero_stride()
{
let a = Array::from_iter(10..37)
.into_shape_with_order((3, 3, 3))
.unwrap();
a.axis_windows_with_stride(Axis(0), 2, 0);
}

/// Test verifies that over sized windows yield nothing
#[test]
fn axis_windows_with_stride_oversized()
{
let a = Array::from_iter(10..37)
.into_shape_with_order((3, 3, 3))
.unwrap();
let mut iter = a.axis_windows_with_stride(Axis(2), 4, 2).into_iter();
assert_eq!(iter.next(), None);
}

/// Simple test for iterating 1d-arrays via `Axis Windows`.
#[test]
fn test_axis_windows_with_stride_1d()
{
let a = Array::from_iter(10..20).into_shape_with_order(10).unwrap();

itertools::assert_equal(a.axis_windows_with_stride(Axis(0), 5, 2), vec![
arr1(&[10, 11, 12, 13, 14]),
arr1(&[12, 13, 14, 15, 16]),
arr1(&[14, 15, 16, 17, 18]),
]);

itertools::assert_equal(a.axis_windows_with_stride(Axis(0), 5, 3), vec![
arr1(&[10, 11, 12, 13, 14]),
arr1(&[13, 14, 15, 16, 17]),
]);
}

/// Simple test for iterating 2d-arrays via `Axis Windows`.
#[test]
fn test_axis_windows_with_stride_2d()
{
let a = Array::from_iter(10..30)
.into_shape_with_order((5, 4))
.unwrap();

itertools::assert_equal(a.axis_windows_with_stride(Axis(0), 2, 1), vec![
arr2(&[[10, 11, 12, 13], [14, 15, 16, 17]]),
arr2(&[[14, 15, 16, 17], [18, 19, 20, 21]]),
arr2(&[[18, 19, 20, 21], [22, 23, 24, 25]]),
arr2(&[[22, 23, 24, 25], [26, 27, 28, 29]]),
]);

itertools::assert_equal(a.axis_windows_with_stride(Axis(0), 2, 2), vec![
arr2(&[[10, 11, 12, 13], [14, 15, 16, 17]]),
arr2(&[[18, 19, 20, 21], [22, 23, 24, 25]]),
]);

itertools::assert_equal(a.axis_windows_with_stride(Axis(0), 2, 3), vec![
arr2(&[[10, 11, 12, 13], [14, 15, 16, 17]]),
arr2(&[[22, 23, 24, 25], [26, 27, 28, 29]]),
]);
}

/// Simple test for iterating 3d-arrays via `Axis Windows`.
#[test]
fn test_axis_windows_with_stride_3d()
{
let a = Array::from_iter(0..27)
.into_shape_with_order((3, 3, 3))
.unwrap();

itertools::assert_equal(a.axis_windows_with_stride(Axis(1), 2, 1), vec![
arr3(&[
[[0, 1, 2], [3, 4, 5]],
[[9, 10, 11], [12, 13, 14]],
[[18, 19, 20], [21, 22, 23]],
]),
arr3(&[
[[3, 4, 5], [6, 7, 8]],
[[12, 13, 14], [15, 16, 17]],
[[21, 22, 23], [24, 25, 26]],
]),
]);

itertools::assert_equal(a.axis_windows_with_stride(Axis(1), 2, 2), vec![
arr3(&[
[[0, 1, 2], [3, 4, 5]],
[[9, 10, 11], [12, 13, 14]],
[[18, 19, 20], [21, 22, 23]],
]),
]);
}

#[test]
fn tests_axis_windows_with_stride_3d_zips_with_1d()
{
let a = Array::from_iter(0..27)
.into_shape_with_order((3, 3, 3))
.unwrap();
let mut b1 = Array::zeros(2);
let mut b2 = Array::zeros(1);

Zip::from(b1.view_mut())
.and(a.axis_windows_with_stride(Axis(1), 2, 1))
.for_each(|b, a| {
*b = a.sum();
});
assert_eq!(b1,arr1(&[207, 261]));

Zip::from(b2.view_mut())
.and(a.axis_windows_with_stride(Axis(1), 2, 2))
.for_each(|b, a| {
*b = a.sum();
});
assert_eq!(b2,arr1(&[207]));
}

#[test]
fn test_window_neg_stride()
{
Expand Down

0 comments on commit c0961a6

Please sign in to comment.