Skip to content

Commit

Permalink
unit test compute_path_length across time ranges
Browse files Browse the repository at this point in the history
  • Loading branch information
niksirbi committed Oct 9, 2024
1 parent 4fade33 commit 05182d8
Showing 1 changed file with 54 additions and 0 deletions.
54 changes: 54 additions & 0 deletions tests/test_unit/test_kinematics.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import re
from contextlib import nullcontext as does_not_raise

import numpy as np
import pytest
Expand Down Expand Up @@ -203,6 +204,59 @@ def test_approximate_derivative_with_invalid_order(order):
kinematics.compute_time_derivative(data, order=order)


@pytest.mark.parametrize(
"start, stop, expected_exception",
[
# full time ranges
(None, None, does_not_raise()),
(0, None, does_not_raise()),
(None, 9, does_not_raise()),
(0, 9, does_not_raise()),
# partial time ranges
(1, 8, does_not_raise()),
(1.5, 8.5, does_not_raise()),
(2, None, does_not_raise()),
(None, 6.3, does_not_raise()),
# invalid time ranges
(0, 10, pytest.raises(ValueError)), # stop > n_frames
(-1, 9, pytest.raises(ValueError)), # start < 0
(9, 0, pytest.raises(ValueError)), # start > stop
("text", 9, pytest.raises(TypeError)), # start is not a number
(0, [0, 1], pytest.raises(TypeError)), # stop is not a number
],
)
def test_compute_path_length_across_time_ranges(
valid_poses_dataset_uniform_linear_motion,
start,
stop,
expected_exception,
):
"""Test that the path length is computed correctly for a uniform linear
motion case.
"""
position = valid_poses_dataset_uniform_linear_motion.position
with expected_exception:
path_length = kinematics.compute_path_length(
position, start=start, stop=stop, nan_policy="scale"
)
# Expected number of steps (displacements) in selected time range
num_steps = 9 # full time range: 10 frames - 1
if start is not None:
num_steps -= np.ceil(start)
if stop is not None:
num_steps -= 9 - np.floor(stop)
# Each step has a magnitude of sqrt(2) in x-y space
expected_path_length = xr.DataArray(
np.ones((2, 3)) * np.sqrt(2) * num_steps,
dims=["individuals", "keypoints"],
coords={
"individuals": position.coords["individuals"],
"keypoints": position.coords["keypoints"],
},
)
xr.testing.assert_allclose(path_length, expected_path_length)


@pytest.fixture
def valid_data_array_for_forward_vector():
"""Return a position data array for an individual with 3 keypoints
Expand Down

0 comments on commit 05182d8

Please sign in to comment.