Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set default callable for match_lists_function #1520

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 69 additions & 112 deletions sleap/info/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,75 +10,6 @@
from sleap.io.dataset import Labels


def matched_instance_distances(
    labels_gt: Labels,
    labels_pr: Labels,
    match_lists_function: Callable,
    frame_range: Optional[range] = None,
) -> Tuple[List[int], np.ndarray, np.ndarray, np.ndarray]:
    """Distances between ground truth and predicted nodes over a set of frames.

    Args:
        labels_gt: the `Labels` object with ground truth data
        labels_pr: the `Labels` object with predicted data
        match_lists_function: function for determining corresponding instances
            Takes two lists of instances and returns "sorted" lists.
        frame_range (optional): range of frames for which to compare data
            If None, we compare every frame in labels_gt with corresponding
            frame in labels_pr.

    Returns:
        Tuple:
        * frame indices map: instance idx (for other matrices) -> frame idx
        * distance matrix: (instances * nodes)
        * ground truth points matrix: (instances * nodes * 2)
        * predicted points matrix: (instances * nodes * 2)
    """
    frame_idxs = []
    points_gt = []
    points_pr = []
    for lf_gt in labels_gt.find(labels_gt.videos[0]):
        frame_idx = lf_gt.frame_idx

        # FIX: frame_range was documented but previously never used; honor it
        # by skipping frames outside the requested range.
        if frame_range is not None and frame_idx not in frame_range:
            continue

        # Get instances from ground truth/predicted labels
        instances_gt = lf_gt.instances
        lfs_pr = labels_pr.find(labels_pr.videos[0], frame_idx=frame_idx)
        if len(lfs_pr):
            instances_pr = lfs_pr[0].instances
        else:
            instances_pr = []

        # Sort ground truth and predicted instances.
        # We'll then compare points between corresponding items in lists.
        # We can use different "match" functions depending on what we want.
        sorted_gt, sorted_pr = match_lists_function(instances_gt, instances_pr)

        # FIX: guard against an empty match result; indexing sorted_gt[0]
        # below would raise IndexError on an empty list.
        if len(sorted_gt) == 0 or len(sorted_pr) == 0:
            continue

        # Convert lists of instances to (instances, nodes, 2) matrices.
        # This allows match_lists_function to return data as either
        # a list of Instances or a (instances, nodes, 2) matrix.
        if not isinstance(sorted_gt[0], np.ndarray):
            sorted_gt = list_points_array(sorted_gt)
        if not isinstance(sorted_pr[0], np.ndarray):
            sorted_pr = list_points_array(sorted_pr)

        points_gt.append(sorted_gt)
        points_pr.append(sorted_pr)
        frame_idxs.extend([frame_idx] * len(sorted_gt))

    # Stack per-frame arrays into single (instances, nodes, 2) matrices.
    points_gt = np.concatenate(points_gt)
    points_pr = np.concatenate(points_pr)

    # Calculate distances between corresponding nodes for all corresponding
    # ground truth and predicted instances.
    D = np.linalg.norm(points_gt - points_pr, axis=2)

    return frame_idxs, D, points_gt, points_pr


def match_instance_lists(
instances_a: List[Union[Instance, PredictedInstance]],
instances_b: List[Union[Instance, PredictedInstance]],
Expand Down Expand Up @@ -165,6 +96,75 @@ def match_instance_lists_nodewise(
return instances_a, best_points_array


def matched_instance_distances(
    labels_gt: Labels,
    labels_pr: Labels,
    match_lists_function: Callable = match_instance_lists_nodewise,
    frame_range: Optional[range] = None,
) -> Tuple[List[int], np.ndarray, np.ndarray, np.ndarray]:
    """Distances between ground truth and predicted nodes over a set of frames.

    Args:
        labels_gt: the `Labels` object with ground truth data
        labels_pr: the `Labels` object with predicted data
        match_lists_function: function for determining corresponding instances
            Takes two lists of instances and returns "sorted" lists.
            Defaults to node-wise matching (`match_instance_lists_nodewise`).
        frame_range (optional): range of frames for which to compare data
            If None, we compare every frame in labels_gt with corresponding
            frame in labels_pr.

    Returns:
        Tuple:
        * frame indices map: instance idx (for other matrices) -> frame idx
        * distance matrix: (instances * nodes)
        * ground truth points matrix: (instances * nodes * 2)
        * predicted points matrix: (instances * nodes * 2)
    """
    frame_idxs = []
    points_gt = []
    points_pr = []
    for lf_gt in labels_gt.find(labels_gt.videos[0]):
        frame_idx = lf_gt.frame_idx

        # FIX: frame_range was documented but previously never used; honor it
        # by skipping frames outside the requested range.
        if frame_range is not None and frame_idx not in frame_range:
            continue

        # Get instances from ground truth/predicted labels
        instances_gt = lf_gt.instances
        lfs_pr = labels_pr.find(labels_pr.videos[0], frame_idx=frame_idx)
        if len(lfs_pr):
            instances_pr = lfs_pr[0].instances
        else:
            instances_pr = []

        # Sort ground truth and predicted instances.
        # We'll then compare points between corresponding items in lists.
        # We can use different "match" functions depending on what we want.
        sorted_gt, sorted_pr = match_lists_function(instances_gt, instances_pr)

        # FIX: guard against an empty match result; indexing sorted_gt[0]
        # below would raise IndexError on an empty list.
        if len(sorted_gt) == 0 or len(sorted_pr) == 0:
            continue

        # Convert lists of instances to (instances, nodes, 2) matrices.
        # This allows match_lists_function to return data as either
        # a list of Instances or a (instances, nodes, 2) matrix.
        if not isinstance(sorted_gt[0], np.ndarray):
            sorted_gt = list_points_array(sorted_gt)
        if not isinstance(sorted_pr[0], np.ndarray):
            sorted_pr = list_points_array(sorted_pr)

        points_gt.append(sorted_gt)
        points_pr.append(sorted_pr)
        frame_idxs.extend([frame_idx] * len(sorted_gt))

    # Stack per-frame arrays into single (instances, nodes, 2) matrices.
    points_gt = np.concatenate(points_gt)
    points_pr = np.concatenate(points_pr)

    # Calculate distances between corresponding nodes for all corresponding
    # ground truth and predicted instances.
    D = np.linalg.norm(points_gt - points_pr, axis=2)

    return frame_idxs, D, points_gt, points_pr


def point_dist(
inst_a: Union[Instance, PredictedInstance],
inst_b: Union[Instance, PredictedInstance],
Expand Down Expand Up @@ -238,46 +238,3 @@ def point_match_count(dist_array: np.ndarray, thresh: float = 5) -> int:
def point_nonmatch_count(dist_array: np.ndarray, thresh: float = 5) -> int:
    """Return how many distances in the array exceed the threshold.

    Complements `point_match_count`: total entries minus those within thresh.
    """
    total = dist_array.shape[0]
    matched = point_match_count(dist_array, thresh)
    return total - matched


if __name__ == "__main__":

    # Ad-hoc evaluation script: compare ground truth labels against
    # predictions for the same video and print per-node mean distances.
    labels_gt = Labels.load_json("tests/data/json_format_v1/centered_pair.json")
    labels_pr = Labels.load_json(
        "tests/data/json_format_v2/centered_pair_predictions.json"
    )

    # OPTION 1

    # Match each ground truth instance node to the closest corresponding node
    # from any predicted instance in the same frame.

    nodewise_matching_func = match_instance_lists_nodewise

    # OPTION 2

    # Match each ground truth instance to a distinct predicted instance:
    # We want to maximize the number of "matching" points between instances,
    # where "match" means the points are within some threshold distance.
    # Note that each sorted list will be as long as the shortest input list.

    instwise_matching_func = lambda gt_list, pr_list: match_instance_lists(
        gt_list, pr_list, point_nonmatch_count
    )

    # PICK THE FUNCTION

    inst_matching_func = nodewise_matching_func
    # inst_matching_func = instwise_matching_func

    # Calculate distances between matched ground truth/predicted instances.
    frame_idxs, D, points_gt, points_pr = matched_instance_distances(
        labels_gt, labels_pr, inst_matching_func
    )

    # Show mean difference for each node
    node_names = labels_gt.skeletons[0].node_names

    # D is (instances, nodes); nanmean skips nodes missing in some instances.
    for node_idx, node_name in enumerate(node_names):
        mean_d = np.nanmean(D[..., node_idx])
        print(f"{node_name}\t\t{mean_d}")
55 changes: 55 additions & 0 deletions tests/info/test_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import numpy as np

from sleap import Labels
from sleap.info.metrics import (
match_instance_lists_nodewise,
matched_instance_distances,
)


def test_matched_instance_distances(centered_pair_labels, centered_pair_predictions):
    """Per-node mean GT/prediction distances should match known values."""
    labels_gt = centered_pair_labels
    labels_pr = centered_pair_predictions

    # Match each ground truth instance node to the closest corresponding node
    # from any predicted instance in the same frame, then compute distances.
    frame_idxs, D, points_gt, points_pr = matched_instance_distances(
        labels_gt, labels_pr, match_instance_lists_nodewise
    )

    # Expected mean distance per node (precomputed reference values).
    expected_values = {
        "head": 0.872426920709296,
        "neck": 0.8016280746914615,
        "thorax": 0.8602021363390538,
        "abdomen": 1.01012200038258,
        "wingL": 1.1297727023475939,
        "wingR": 1.0869857897008424,
        "forelegL1": 0.780584225081443,
        "forelegL2": 1.170805798894702,
        "forelegL3": 1.1020486509389473,
        "forelegR1": 0.9014698776116817,
        "forelegR2": 0.9448001033112047,
        "forelegR3": 1.308385214215777,
        "midlegL1": 0.9095691623265347,
        "midlegL2": 1.2203595627907582,
        "midlegL3": 0.9813843358470163,
        "midlegR1": 0.9871017182813739,
        "midlegR2": 1.0209829335569256,
        "midlegR3": 1.0990681234096988,
        "hindlegL1": 1.0005335192834348,
        "hindlegL2": 1.273539518539708,
        "hindlegL3": 1.1752245985832817,
        "hindlegR1": 1.1402833959265248,
        "hindlegR2": 1.3143221301212737,
        "hindlegR3": 1.0441458592503365,
    }

    # Verify the mean distance for each node against its reference value.
    node_names = labels_gt.skeletons[0].node_names
    for node_idx, node_name in enumerate(node_names):
        observed = np.nanmean(D[..., node_idx])
        assert np.isclose(observed, expected_values[node_name], atol=1e-6)
roomrys marked this conversation as resolved.
Show resolved Hide resolved
Loading