From 7f04d496329787ee724304ba618dcbb109c80b2d Mon Sep 17 00:00:00 2001 From: b-peri Date: Tue, 10 Sep 2024 10:57:48 +0100 Subject: [PATCH] Extended testing and added `front_keypoint` argument to `compute_head_direction_vector()` --- movement/analysis/kinematics.py | 35 ++++++++++++++++++++++++++++-- tests/test_unit/test_kinematics.py | 35 ++++++++++++++++++++++-------- 2 files changed, 59 insertions(+), 11 deletions(-) diff --git a/movement/analysis/kinematics.py b/movement/analysis/kinematics.py index 524def04..2e93b4ca 100644 --- a/movement/analysis/kinematics.py +++ b/movement/analysis/kinematics.py @@ -4,6 +4,7 @@ import xarray as xr from movement.utils.logging import log_error +from movement.utils.vector import convert_to_unit def compute_displacement(data: xr.DataArray) -> xr.DataArray: @@ -163,13 +164,19 @@ def _compute_approximate_time_derivative( def compute_head_direction_vector( - data: xr.DataArray, left_keypoint: str, right_keypoint: str + data: xr.DataArray, + left_keypoint: str, + right_keypoint: str, + front_keypoint: str | None = None, ): """Compute the 2D head direction vector given two keypoints on the head. The head direction vector is computed as a vector perpendicular to the line connecting two keypoints on either side of the head, pointing - forwards (in the rostral direction). + forwards (in the rostral direction). As the forward direction may + differ between coordinate systems, the front keypoint is used ..., + when present. Otherwise, we assume that coordinates are given in the + image coordinate system (where the origin is located in the top-left). Parameters ---------- @@ -181,6 +188,8 @@ def compute_head_direction_vector( Name of the left keypoint, e.g., "left_ear" right_keypoint : str Name of the right keypoint, e.g., "right_ear" + front_keypoint : str | None + (Optional) Name of the front keypoint, e.g., "nose". Returns ------- @@ -191,6 +200,9 @@ def compute_head_direction_vector( """ # Validate input dataset + _validate_type_data_array(data) + _validate_time_keypoints_space_dimensions(data) + if left_keypoint == right_keypoint: raise log_error( ValueError, "The left and right keypoints may not be identical." @@ -217,6 +229,25 @@ def compute_head_direction_vector( :, :, :-1 ] + # Check computed head_vector is pointing in the same direction as vector + # from head midpoint to snout + if front_keypoint: + head_front = data.sel(keypoints=front_keypoint, drop=True) + head_midpoint = (head_right + head_left) / 2 + mid_to_front_vector = head_front - head_midpoint + dot_product_array = ( + convert_to_unit(head_vector.sel(individuals=data.individuals[0])) + * convert_to_unit(mid_to_front_vector).sel( + individuals=data.individuals[0] + ) + ).sum(dim="space") + median_dot_product = float(dot_product_array.median(dim="time").values) + if median_dot_product < 0: + perpendicular_vector = np.array([0, 0, 1]) + head_vector.values = np.cross( + right_to_left_vector, perpendicular_vector + )[:, :, :-1] + return head_vector diff --git a/tests/test_unit/test_kinematics.py b/tests/test_unit/test_kinematics.py index 0795995f..d117314d 100644 --- a/tests/test_unit/test_kinematics.py +++ b/tests/test_unit/test_kinematics.py @@ -1,3 +1,6 @@ +import re +from contextlib import nullcontext as does_not_raise + import numpy as np import pytest import xarray as xr @@ -188,7 +191,7 @@ class TestNavigation: """Test suite for navigation-related functions in the kinematics module.""" @pytest.fixture - def mock_data_array(self): + def mock_dataarray(self): """Return a mock DataArray containing four known head orientations.""" time = [0, 1, 2, 3] individuals = ["individual_0"] @@ -213,7 +216,7 @@ def mock_data_array(self): return ds @pytest.fixture - def mock_data_array_3D(self): + def mock_dataarray_3D(self): """Return a 3D DataArray containing a known head orientation.""" time = [0] individuals = ["individual_0"] @@ -235,7 +238,7 @@ def mock_data_array_3D(self): return ds def test_compute_head_direction_vector( - self, mock_data_array, mock_data_array_3D + self, mock_dataarray, mock_dataarray_3D ): """Test that the correct head direction vectors are computed from a basic mock dataset. @@ -245,14 +248,14 @@ def test_compute_head_direction_vector( # Catch incorrect datatype with pytest.raises(TypeError, match="must be an xarray.DataArray"): kinematics.compute_head_direction_vector( - mock_data_array.values, "left_ear", "right_ear" + mock_dataarray.values, "left_ear", "right_ear" ) # Catch incorrect dimensions with pytest.raises( AttributeError, match="'time', 'space', and 'keypoints'" ): - mock_data_keypoint = mock_data_array.sel( + mock_data_keypoint = mock_dataarray.sel( keypoints="nose", drop=True ) kinematics.compute_head_direction_vector( @@ -262,20 +265,21 @@ def test_compute_head_direction_vector( # Catch identical left and right keypoints with pytest.raises(ValueError, match="keypoints may not be identical"): kinematics.compute_head_direction_vector( - mock_data_array, "left_ear", "left_ear" + mock_dataarray, "left_ear", "left_ear" ) # Catch incorrect spatial dimensions with pytest.raises( - ValueError, match="must have 2 (and only 2) spatial dimensions" + ValueError, + match=re.escape("must have 2 (and only 2) spatial dimensions"), ): kinematics.compute_head_direction_vector( - mock_data_array_3D, "left", "right" + mock_dataarray_3D, "left", "right" ) # Test that output contains correct datatype, dimensions, and values head_vector = kinematics.compute_head_direction_vector( - mock_data_array, "left_ear", "right_ear" + mock_dataarray, "left_ear", "right_ear" ) known_vectors = np.array([[[0, 2]], [[-2, 0]], [[0, -2]], [[2, 0]]]) @@ -285,3 +289,16 @@ def test_compute_head_direction_vector( and ("keypoints" not in head_vector.dims) ) assert np.equal(head_vector.values, known_vectors).all() + + # Test behaviour with NaNs + nan_dataarray = mock_dataarray.where( + (mock_dataarray.time != 1) + | (mock_dataarray.keypoints != "left_ear") + ) + head_vector = kinematics.compute_head_direction_vector( + nan_dataarray, "left_ear", "right_ear" + ) + assert ( + np.isnan(head_vector.values[1, 0, :]).all() + and not np.isnan(head_vector.values[[0, 2, 3], 0, :]).any() + )