Skip to content

Commit

Permalink
rewrote compute_path_length with various nan policies
Browse files Browse the repository at this point in the history
  • Loading branch information
niksirbi committed Oct 9, 2024
1 parent 92eec53 commit 4fade33
Showing 1 changed file with 75 additions and 19 deletions.
94 changes: 75 additions & 19 deletions movement/analysis/kinematics.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
import xarray as xr

from movement.utils.logging import log_error
from movement.utils.logging import log_error, log_warning
from movement.utils.vector import compute_norm
from movement.validators.arrays import validate_dims_coords

Expand Down Expand Up @@ -199,26 +199,44 @@ def compute_path_length(
data: xr.DataArray,
start: float | None = None,
stop: float | None = None,
nan_policy: Literal["drop", "scale"] = "drop",
nan_warn_threshold: float = 0.2,
) -> xr.DataArray:
"""Compute the length of a path travelled between two time points.
The path length is defined as the sum of the norms (magnitudes) of the
displacement vectors between two time points ``start`` and ``stop``,
which should be provided in the time units of the data array.
If not specified, the minimum and maximum time points in the data array
are used as start and stop times, respectively.
If not specified, the minimum and maximum time coordinates of the data
array are used as start and stop times, respectively.
Parameters
----------
data : xarray.DataArray
The input data containing position information in Cartesian
coordinates, with ``time`` and ``space`` as dimensions.
coordinates, with ``time`` and ``space`` among the dimensions.
start : float, optional
The time point to consider as the start of a path.
If None (default), the minimum time point in the data is used.
The time to consider as the path's starting point. If None (default),
the minimum time coordinate in the data is used.
stop : float, optional
The time point to consider as the end of a path.
If None (default), the maximum time point in the data is used.
The time to consider as the path's end point. If None (default),
the maximum time coordinate in the data is used.
nan_policy : str, optional
Policy to handle NaN (missing) values. Can be one of the following:
- ``"drop"``: drop any NaN values before computing path length. This
is the default behavior, and it equates to assuming that a track
follows a straight line between two valid points flanking a missing
segment. This approach tends to underestimate the path length,
and the error increases with the number of missing values.
- ``"scale"``: scale path length based on the proportion of valid
values per point track. For example, if only 80% of the values are
present, the path length will be computed based on these values,
and the result will be multiplied by 1/0.8 = 1.25. This approach
assumes that the point's dynamics are similar across present
and missing time segments, which may not be the case.
nan_warn_threshold : float, optional
If more than this proportion of values are missing in any point track,
a warning will be emitted. Defaults to 0.2 (20%).
Returns
-------
Expand All @@ -228,21 +246,59 @@ def compute_path_length(
and ``space`` which will be removed.
"""
# We choose to validate the time dimension here, despite the fact that
# it will be also validated later in the compute_displacement function.
# This is because we rely on the time dimension for start and stop values.
# We validate the time dimension here, on top of its later validation
# inside compute_displacement, because we rely on it for start/stop times.
validate_dims_coords(data, {"time": []})
# Now validate the start and stop times
_validate_start_stop_times(data, start, stop)

# Handle the case where the start or stop times are not provided
start = data.time.min() if start is None else start
stop = data.time.max() if stop is None else stop
# Select data within the specified time range
data = data.sel(time=slice(start, stop))

# Emit a warning for point tracks with many missing values
nan_counts = data.isnull().any(dim="space").sum(dim="time")
dims_to_stack = [dim for dim in data.dims if dim not in ["time", "space"]]
# Stack individual and keypoints dims into a single 'tracks' dimension
stacked_nan_counts = nan_counts.stack(tracks=dims_to_stack)
tracks_with_warning = stacked_nan_counts.where(
stacked_nan_counts > nan_warn_threshold, drop=True
).tracks.values
if len(tracks_with_warning) > 0:
log_warning(
"The following point tracks have more than "
f"{nan_warn_threshold * 100}% missing values, which may lead to "
"unreliable path length estimates: "
f"{', '.join(tracks_with_warning)}."
)

if nan_policy == "drop":
stacked_data = data.stack(tracks=dims_to_stack)
# Create an empty data array to hold the path length for each track
stacked_path_length = xr.zeros_like(stacked_nan_counts)
# Compute path length for each track
for track_name in stacked_data.tracks:
track_data = stacked_data.sel(tracks=track_name, drop=True).dropna(
dim="time", how="any"
)
stacked_path_length.loc[track_name] = compute_norm(
compute_displacement(track_data)
).sum(dim="time")
# Return the unstacked path length (restore individual and keypoints)
return stacked_path_length.unstack("tracks")

elif nan_policy == "scale":
valid_path_length = compute_norm(compute_displacement(data)).sum(
dim="time",
skipna=True, # path length only for valid points
)
scale_factor = 1 / (1 - nan_counts / data.sizes["time"])
return valid_path_length * scale_factor

# Compute the sum of the displacement norms in the given time range
selected_data = data.sel(time=slice(start, stop))
displacement_norm = compute_norm(compute_displacement(selected_data))
return displacement_norm.sum(dim="time")
else:
raise log_error(
ValueError,
f"Invalid value for nan_policy: {nan_policy}. "
"Must be one of 'drop' or 'weight'.",
)


def compute_forward_vector(
Expand Down

0 comments on commit 4fade33

Please sign in to comment.