diff --git a/sleap/info/metrics.py b/sleap/info/metrics.py index 2ac61d339..5bec077e4 100644 --- a/sleap/info/metrics.py +++ b/sleap/info/metrics.py @@ -10,75 +10,6 @@ from sleap.io.dataset import Labels -def matched_instance_distances( - labels_gt: Labels, - labels_pr: Labels, - match_lists_function: Callable, - frame_range: Optional[range] = None, -) -> Tuple[List[int], np.ndarray, np.ndarray, np.ndarray]: - - """ - Distances between ground truth and predicted nodes over a set of frames. - - Args: - labels_gt: the `Labels` object with ground truth data - labels_pr: the `Labels` object with predicted data - match_lists_function: function for determining corresponding instances - Takes two lists of instances and returns "sorted" lists. - frame_range (optional): range of frames for which to compare data - If None, we compare every frame in labels_gt with corresponding - frame in labels_pr. - Returns: - Tuple: - * frame indices map: instance idx (for other matrices) -> frame idx - * distance matrix: (instances * nodes) - * ground truth points matrix: (instances * nodes * 2) - * predicted points matrix: (instances * nodes * 2) - """ - - frame_idxs = [] - points_gt = [] - points_pr = [] - for lf_gt in labels_gt.find(labels_gt.videos[0]): - frame_idx = lf_gt.frame_idx - - # Get instances from ground truth/predicted labels - instances_gt = lf_gt.instances - lfs_pr = labels_pr.find(labels_pr.videos[0], frame_idx=frame_idx) - if len(lfs_pr): - instances_pr = lfs_pr[0].instances - else: - instances_pr = [] - - # Sort ground truth and predicted instances. - # We'll then compare points between corresponding items in lists. - # We can use different "match" functions depending on what we want. - sorted_gt, sorted_pr = match_lists_function(instances_gt, instances_pr) - - # Convert lists of instances to (instances, nodes, 2) matrices. - # This allows match_lists_function to return data as either - # a list of Instances or a (instances, nodes, 2) matrix. - if type(sorted_gt[0]) != np.ndarray: - sorted_gt = list_points_array(sorted_gt) - if type(sorted_pr[0]) != np.ndarray: - sorted_pr = list_points_array(sorted_pr) - - points_gt.append(sorted_gt) - points_pr.append(sorted_pr) - frame_idxs.extend([frame_idx] * len(sorted_gt)) - - # Convert arrays to numpy matrixes - # instances * nodes * (x,y) - points_gt = np.concatenate(points_gt) - points_pr = np.concatenate(points_pr) - - # Calculate distances between corresponding nodes for all corresponding - # ground truth and predicted instances. - D = np.linalg.norm(points_gt - points_pr, axis=2) - - return frame_idxs, D, points_gt, points_pr - - def match_instance_lists( instances_a: List[Union[Instance, PredictedInstance]], instances_b: List[Union[Instance, PredictedInstance]], @@ -165,6 +96,75 @@ def match_instance_lists_nodewise( return instances_a, best_points_array +def matched_instance_distances( + labels_gt: Labels, + labels_pr: Labels, + match_lists_function: Callable = match_instance_lists_nodewise, + frame_range: Optional[range] = None, +) -> Tuple[List[int], np.ndarray, np.ndarray, np.ndarray]: + + """ + Distances between ground truth and predicted nodes over a set of frames. + + Args: + labels_gt: the `Labels` object with ground truth data + labels_pr: the `Labels` object with predicted data + match_lists_function: function for determining corresponding instances + Takes two lists of instances and returns "sorted" lists. + frame_range (optional): range of frames for which to compare data + If None, we compare every frame in labels_gt with corresponding + frame in labels_pr. + Returns: + Tuple: + * frame indices map: instance idx (for other matrices) -> frame idx + * distance matrix: (instances * nodes) + * ground truth points matrix: (instances * nodes * 2) + * predicted points matrix: (instances * nodes * 2) + """ + + frame_idxs = [] + points_gt = [] + points_pr = [] + for lf_gt in labels_gt.find(labels_gt.videos[0]): + frame_idx = lf_gt.frame_idx + + # Get instances from ground truth/predicted labels + instances_gt = lf_gt.instances + lfs_pr = labels_pr.find(labels_pr.videos[0], frame_idx=frame_idx) + if len(lfs_pr): + instances_pr = lfs_pr[0].instances + else: + instances_pr = [] + + # Sort ground truth and predicted instances. + # We'll then compare points between corresponding items in lists. + # We can use different "match" functions depending on what we want. + sorted_gt, sorted_pr = match_lists_function(instances_gt, instances_pr) + + # Convert lists of instances to (instances, nodes, 2) matrices. + # This allows match_lists_function to return data as either + # a list of Instances or a (instances, nodes, 2) matrix. + if type(sorted_gt[0]) != np.ndarray: + sorted_gt = list_points_array(sorted_gt) + if type(sorted_pr[0]) != np.ndarray: + sorted_pr = list_points_array(sorted_pr) + + points_gt.append(sorted_gt) + points_pr.append(sorted_pr) + frame_idxs.extend([frame_idx] * len(sorted_gt)) + + # Convert arrays to numpy matrixes + # instances * nodes * (x,y) + points_gt = np.concatenate(points_gt) + points_pr = np.concatenate(points_pr) + + # Calculate distances between corresponding nodes for all corresponding + # ground truth and predicted instances. + D = np.linalg.norm(points_gt - points_pr, axis=2) + + return frame_idxs, D, points_gt, points_pr + + def point_dist( inst_a: Union[Instance, PredictedInstance], inst_b: Union[Instance, PredictedInstance], @@ -238,46 +238,3 @@ def point_match_count(dist_array: np.ndarray, thresh: float = 5) -> int: def point_nonmatch_count(dist_array: np.ndarray, thresh: float = 5) -> int: """Given an array of distances, returns number which are not <= threshold.""" return dist_array.shape[0] - point_match_count(dist_array, thresh) - - -if __name__ == "__main__": - - labels_gt = Labels.load_json("tests/data/json_format_v1/centered_pair.json") - labels_pr = Labels.load_json( - "tests/data/json_format_v2/centered_pair_predictions.json" - ) - - # OPTION 1 - - # Match each ground truth instance node to the closest corresponding node - # from any predicted instance in the same frame. - - nodewise_matching_func = match_instance_lists_nodewise - - # OPTION 2 - - # Match each ground truth instance to a distinct predicted instance: - # We want to maximize the number of "matching" points between instances, - # where "match" means the points are within some threshold distance. - # Note that each sorted list will be as long as the shorted input list. - - instwise_matching_func = lambda gt_list, pr_list: match_instance_lists( - gt_list, pr_list, point_nonmatch_count - ) - - # PICK THE FUNCTION - - inst_matching_func = nodewise_matching_func - # inst_matching_func = instwise_matching_func - - # Calculate distances - frame_idxs, D, points_gt, points_pr = matched_instance_distances( - labels_gt, labels_pr, inst_matching_func - ) - - # Show mean difference for each node - node_names = labels_gt.skeletons[0].node_names - - for node_idx, node_name in enumerate(node_names): - mean_d = np.nanmean(D[..., node_idx]) - print(f"{node_name}\t\t{mean_d}") diff --git a/tests/info/test_metrics.py b/tests/info/test_metrics.py new file mode 100644 index 000000000..0d2e097e6 --- /dev/null +++ b/tests/info/test_metrics.py @@ -0,0 +1,55 @@ +import numpy as np + +from sleap import Labels +from sleap.info.metrics import ( + match_instance_lists_nodewise, + matched_instance_distances, +) + + +def test_matched_instance_distances(centered_pair_labels, centered_pair_predictions): + labels_gt = centered_pair_labels + labels_pr = centered_pair_predictions + + # Match each ground truth instance node to the closest corresponding node + # from any predicted instance in the same frame. + + inst_matching_func = match_instance_lists_nodewise + + # Calculate distances + frame_idxs, D, points_gt, points_pr = matched_instance_distances( + labels_gt, labels_pr, inst_matching_func + ) + + # Show mean difference for each node + node_names = labels_gt.skeletons[0].node_names + expected_values = { + "head": 0.872426920709296, + "neck": 0.8016280746914615, + "thorax": 0.8602021363390538, + "abdomen": 1.01012200038258, + "wingL": 1.1297727023475939, + "wingR": 1.0869857897008424, + "forelegL1": 0.780584225081443, + "forelegL2": 1.170805798894702, + "forelegL3": 1.1020486509389473, + "forelegR1": 0.9014698776116817, + "forelegR2": 0.9448001033112047, + "forelegR3": 1.308385214215777, + "midlegL1": 0.9095691623265347, + "midlegL2": 1.2203595627907582, + "midlegL3": 0.9813843358470163, + "midlegR1": 0.9871017182813739, + "midlegR2": 1.0209829335569256, + "midlegR3": 1.0990681234096988, + "hindlegL1": 1.0005335192834348, + "hindlegL2": 1.273539518539708, + "hindlegL3": 1.1752245985832817, + "hindlegR1": 1.1402833959265248, + "hindlegR2": 1.3143221301212737, + "hindlegR3": 1.0441458592503365, + } + + for node_idx, node_name in enumerate(node_names): + mean_d = np.nanmean(D[..., node_idx]) + assert np.isclose(mean_d, expected_values[node_name], atol=1e-6)