Commit

Set default callable for match_lists_function (#1520)
* Set default for `match_lists_function`

* Move test code to official tests

* Check using expected values
roomrys authored Sep 29, 2023
1 parent 4100153 commit ed77b49
Showing 2 changed files with 124 additions and 112 deletions.
181 changes: 69 additions & 112 deletions sleap/info/metrics.py
@@ -10,75 +10,6 @@
from sleap.io.dataset import Labels


def matched_instance_distances(
labels_gt: Labels,
labels_pr: Labels,
match_lists_function: Callable,
frame_range: Optional[range] = None,
) -> Tuple[List[int], np.ndarray, np.ndarray, np.ndarray]:

"""
Distances between ground truth and predicted nodes over a set of frames.
Args:
labels_gt: the `Labels` object with ground truth data
labels_pr: the `Labels` object with predicted data
match_lists_function: function for determining corresponding instances
Takes two lists of instances and returns "sorted" lists.
frame_range (optional): range of frames for which to compare data
If None, we compare every frame in labels_gt with corresponding
frame in labels_pr.
Returns:
Tuple:
* frame indices map: instance idx (for other matrices) -> frame idx
* distance matrix: (instances * nodes)
* ground truth points matrix: (instances * nodes * 2)
* predicted points matrix: (instances * nodes * 2)
"""

frame_idxs = []
points_gt = []
points_pr = []
for lf_gt in labels_gt.find(labels_gt.videos[0]):
frame_idx = lf_gt.frame_idx

# Get instances from ground truth/predicted labels
instances_gt = lf_gt.instances
lfs_pr = labels_pr.find(labels_pr.videos[0], frame_idx=frame_idx)
if len(lfs_pr):
instances_pr = lfs_pr[0].instances
else:
instances_pr = []

# Sort ground truth and predicted instances.
# We'll then compare points between corresponding items in lists.
# We can use different "match" functions depending on what we want.
sorted_gt, sorted_pr = match_lists_function(instances_gt, instances_pr)

# Convert lists of instances to (instances, nodes, 2) matrices.
# This allows match_lists_function to return data as either
# a list of Instances or a (instances, nodes, 2) matrix.
if type(sorted_gt[0]) != np.ndarray:
sorted_gt = list_points_array(sorted_gt)
if type(sorted_pr[0]) != np.ndarray:
sorted_pr = list_points_array(sorted_pr)

points_gt.append(sorted_gt)
points_pr.append(sorted_pr)
frame_idxs.extend([frame_idx] * len(sorted_gt))

# Convert arrays to numpy matrixes
# instances * nodes * (x,y)
points_gt = np.concatenate(points_gt)
points_pr = np.concatenate(points_pr)

# Calculate distances between corresponding nodes for all corresponding
# ground truth and predicted instances.
D = np.linalg.norm(points_gt - points_pr, axis=2)

return frame_idxs, D, points_gt, points_pr


def match_instance_lists(
instances_a: List[Union[Instance, PredictedInstance]],
instances_b: List[Union[Instance, PredictedInstance]],
@@ -165,6 +96,75 @@ def match_instance_lists_nodewise(
return instances_a, best_points_array


def matched_instance_distances(
labels_gt: Labels,
labels_pr: Labels,
match_lists_function: Callable = match_instance_lists_nodewise,
frame_range: Optional[range] = None,
) -> Tuple[List[int], np.ndarray, np.ndarray, np.ndarray]:

"""
Distances between ground truth and predicted nodes over a set of frames.
Args:
labels_gt: the `Labels` object with ground truth data
labels_pr: the `Labels` object with predicted data
match_lists_function: function for determining corresponding instances
Takes two lists of instances and returns "sorted" lists.
frame_range (optional): range of frames for which to compare data
If None, we compare every frame in labels_gt with corresponding
frame in labels_pr.
Returns:
Tuple:
* frame indices map: instance idx (for other matrices) -> frame idx
* distance matrix: (instances * nodes)
* ground truth points matrix: (instances * nodes * 2)
* predicted points matrix: (instances * nodes * 2)
"""

frame_idxs = []
points_gt = []
points_pr = []
for lf_gt in labels_gt.find(labels_gt.videos[0]):
frame_idx = lf_gt.frame_idx

# Get instances from ground truth/predicted labels
instances_gt = lf_gt.instances
lfs_pr = labels_pr.find(labels_pr.videos[0], frame_idx=frame_idx)
if len(lfs_pr):
instances_pr = lfs_pr[0].instances
else:
instances_pr = []

# Sort ground truth and predicted instances.
# We'll then compare points between corresponding items in lists.
# We can use different "match" functions depending on what we want.
sorted_gt, sorted_pr = match_lists_function(instances_gt, instances_pr)

# Convert lists of instances to (instances, nodes, 2) matrices.
# This allows match_lists_function to return data as either
# a list of Instances or a (instances, nodes, 2) matrix.
if type(sorted_gt[0]) != np.ndarray:
sorted_gt = list_points_array(sorted_gt)
if type(sorted_pr[0]) != np.ndarray:
sorted_pr = list_points_array(sorted_pr)

points_gt.append(sorted_gt)
points_pr.append(sorted_pr)
frame_idxs.extend([frame_idx] * len(sorted_gt))

# Convert arrays to numpy matrixes
# instances * nodes * (x,y)
points_gt = np.concatenate(points_gt)
points_pr = np.concatenate(points_pr)

# Calculate distances between corresponding nodes for all corresponding
# ground truth and predicted instances.
D = np.linalg.norm(points_gt - points_pr, axis=2)

return frame_idxs, D, points_gt, points_pr


def point_dist(
inst_a: Union[Instance, PredictedInstance],
inst_b: Union[Instance, PredictedInstance],
@@ -238,46 +238,3 @@ def point_match_count(dist_array: np.ndarray, thresh: float = 5) -> int:
def point_nonmatch_count(dist_array: np.ndarray, thresh: float = 5) -> int:
"""Given an array of distances, returns number which are not <= threshold."""
return dist_array.shape[0] - point_match_count(dist_array, thresh)


if __name__ == "__main__":

labels_gt = Labels.load_json("tests/data/json_format_v1/centered_pair.json")
labels_pr = Labels.load_json(
"tests/data/json_format_v2/centered_pair_predictions.json"
)

# OPTION 1

# Match each ground truth instance node to the closest corresponding node
# from any predicted instance in the same frame.

nodewise_matching_func = match_instance_lists_nodewise

# OPTION 2

# Match each ground truth instance to a distinct predicted instance:
# We want to maximize the number of "matching" points between instances,
# where "match" means the points are within some threshold distance.
    # Note that each sorted list will be as long as the shorter input list.

instwise_matching_func = lambda gt_list, pr_list: match_instance_lists(
gt_list, pr_list, point_nonmatch_count
)

# PICK THE FUNCTION

inst_matching_func = nodewise_matching_func
# inst_matching_func = instwise_matching_func

# Calculate distances
frame_idxs, D, points_gt, points_pr = matched_instance_distances(
labels_gt, labels_pr, inst_matching_func
)

# Show mean difference for each node
node_names = labels_gt.skeletons[0].node_names

for node_idx, node_name in enumerate(node_names):
mean_d = np.nanmean(D[..., node_idx])
print(f"{node_name}\t\t{mean_d}")
55 changes: 55 additions & 0 deletions tests/info/test_metrics.py
@@ -0,0 +1,55 @@
import numpy as np

from sleap import Labels
from sleap.info.metrics import (
match_instance_lists_nodewise,
matched_instance_distances,
)


def test_matched_instance_distances(centered_pair_labels, centered_pair_predictions):
labels_gt = centered_pair_labels
labels_pr = centered_pair_predictions

# Match each ground truth instance node to the closest corresponding node
# from any predicted instance in the same frame.

inst_matching_func = match_instance_lists_nodewise

# Calculate distances
frame_idxs, D, points_gt, points_pr = matched_instance_distances(
labels_gt, labels_pr, inst_matching_func
)

# Show mean difference for each node
node_names = labels_gt.skeletons[0].node_names
expected_values = {
"head": 0.872426920709296,
"neck": 0.8016280746914615,
"thorax": 0.8602021363390538,
"abdomen": 1.01012200038258,
"wingL": 1.1297727023475939,
"wingR": 1.0869857897008424,
"forelegL1": 0.780584225081443,
"forelegL2": 1.170805798894702,
"forelegL3": 1.1020486509389473,
"forelegR1": 0.9014698776116817,
"forelegR2": 0.9448001033112047,
"forelegR3": 1.308385214215777,
"midlegL1": 0.9095691623265347,
"midlegL2": 1.2203595627907582,
"midlegL3": 0.9813843358470163,
"midlegR1": 0.9871017182813739,
"midlegR2": 1.0209829335569256,
"midlegR3": 1.0990681234096988,
"hindlegL1": 1.0005335192834348,
"hindlegL2": 1.273539518539708,
"hindlegL3": 1.1752245985832817,
"hindlegR1": 1.1402833959265248,
"hindlegR2": 1.3143221301212737,
"hindlegR3": 1.0441458592503365,
}

for node_idx, node_name in enumerate(node_names):
mean_d = np.nanmean(D[..., node_idx])
assert np.isclose(mean_d, expected_values[node_name], atol=1e-6)
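
The `centered_pair_labels` and `centered_pair_predictions` arguments are presumably pytest fixtures provided by the project's test suite. A quick sketch for running only this new module locally (assumes pytest is installed and the interpreter is started from the repository root):

import pytest

# Run just the new metrics tests; -q keeps the output terse.
pytest.main(["-q", "tests/info/test_metrics.py"])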
