Skip to content

Commit

Permalink
Handle scalar and 1d dims
Browse files Browse the repository at this point in the history
  • Loading branch information
lochhh committed Sep 9, 2024
1 parent a20d283 commit f35d96e
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 2 deletions.
7 changes: 6 additions & 1 deletion movement/analysis/kinematics.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,10 @@ def cdist(
core_dim = "individuals" if dim == "keypoints" else "keypoints"
elem1 = getattr(a, dim).item()
elem2 = getattr(b, dim).item()
if a.coords[core_dim].ndim == 0:
a = a.expand_dims(core_dim).transpose("time", "space", core_dim)
if b.coords[core_dim].ndim == 0:
b = b.expand_dims(core_dim).transpose("time", "space", core_dim)
result = xr.apply_ufunc(
_cdist,
a,
Expand All @@ -235,7 +239,8 @@ def cdist(
elem2: getattr(a, core_dim).values,
}
)
return result
# Drop any squeezed coordinates
return result.squeeze(drop=True)


def compute_interindividual_distances(
Expand Down
54 changes: 53 additions & 1 deletion tests/test_unit/test_kinematics.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import itertools
from contextlib import nullcontext as does_not_raise

import numpy as np
import pytest
Expand Down Expand Up @@ -227,7 +228,9 @@ def test_approximate_derivative_with_invalid_order(order):
),
],
)
def test_cdist(dim, pairs, expected_data, pairwise_distances_dataset):
def test_cdist_with_known_values(
dim, pairs, expected_data, pairwise_distances_dataset
):
"""Test the computation of pairwise distances with known values."""
core_dim = "keypoints" if dim == "individuals" else "individuals"
input_dataarray = pairwise_distances_dataset.position
Expand All @@ -249,6 +252,55 @@ def test_cdist(dim, pairs, expected_data, pairwise_distances_dataset):
)


@pytest.mark.parametrize(
"selection_fn",
[
lambda position: (
position.sel(individuals="ind1"),
position.sel(individuals="ind2"),
), # individuals dim is scalar
lambda position: (
position.where(
position.individuals == "ind1", drop=True
).squeeze(),
position.where(
position.individuals == "ind2", drop=True
).squeeze(),
), # individuals dim is 1D
lambda position: (
position.sel(individuals="ind1", keypoints="key1"),
position.sel(individuals="ind2", keypoints="key1"),
), # both individuals and keypoints dims are scalar
lambda position: (
position.where(position.keypoints == "key1", drop=True).sel(
individuals="ind1"
),
position.where(position.keypoints == "key1", drop=True).sel(
individuals="ind2"
),
), # keypoints dim is 1D
],
ids=[
"dim_has_ndim_0",
"dim_has_ndim_1",
"core_dim_has_ndim_0",
"core_dim_has_ndim_1",
],
)
def test_cdist_with_single_dim_inputs(
pairwise_distances_dataset, selection_fn
):
"""Test that the computation of pairwise distances
works regardless of whether the input DataArrays have
```dim``` and ```core_dim``` being either scalar (ndim=0)
or 1D (ndim=1).
"""
position = pairwise_distances_dataset.position
a, b = selection_fn(position)
with does_not_raise():
kinematics.cdist(a, b, "individuals")


def expected_pairwise_distances(pairs, input_ds, dim):
"""Return a list of the expected data variable names
for pairwise distances tests.
Expand Down

0 comments on commit f35d96e

Please sign in to comment.