diff --git a/keypoint_detection/models/detector.py b/keypoint_detection/models/detector.py index 5c4cf76..be8a413 100644 --- a/keypoint_detection/models/detector.py +++ b/keypoint_detection/models/detector.py @@ -120,6 +120,9 @@ def __init__( ] self.maximal_gt_keypoint_pixel_distances = maximal_gt_keypoint_pixel_distances + self.ap_training_metrics = [ + KeypointAPMetrics(self.maximal_gt_keypoint_pixel_distances) for _ in self.keypoint_channel_configuration + ] self.ap_validation_metrics = [ KeypointAPMetrics(self.maximal_gt_keypoint_pixel_distances) for _ in self.keypoint_channel_configuration ] @@ -251,8 +254,17 @@ def shared_step(self, batch, batch_idx, include_visualization_data_in_result_dic def training_step(self, train_batch, batch_idx): log_images = batch_idx == 0 and self.current_epoch > 0 + should_log_ap = ( + self.is_ap_epoch() + ) # and batch_idx < 20 # limit AP calculation to first 20 batches to save time + include_vis_data = log_images or should_log_ap + + result_dict = self.shared_step( + train_batch, batch_idx, include_visualization_data_in_result_dict=include_vis_data + ) - result_dict = self.shared_step(train_batch, batch_idx, include_visualization_data_in_result_dict=log_images) + if should_log_ap: + self.update_ap_metrics(result_dict, self.ap_training_metrics) if log_images: image_grids = self.visualize_predictions_channels(result_dict) @@ -340,20 +352,32 @@ def log_and_reset_mean_ap(self, mode: str): mean_ap_per_threshold = torch.zeros(len(self.maximal_gt_keypoint_pixel_distances)) metrics = self.ap_test_metrics if mode == "test" else self.ap_validation_metrics + # calculate APs for each channel and each threshold distance, and log them + print(f" # {mode} metrics:") for channel_idx, channel_name in enumerate(self.keypoint_channel_configuration): channel_aps = self.compute_and_log_metrics_for_channel(metrics[channel_idx], channel_name, mode) mean_ap_per_threshold += torch.tensor(channel_aps) + # calculate the mAP over all channels for each threshold distance, and log them for i, maximal_distance in enumerate(self.maximal_gt_keypoint_pixel_distances): self.log( f"{mode}/meanAP/d={float(maximal_distance):.1f}", mean_ap_per_threshold[i] / len(self.keypoint_channel_configuration), ) + # calculate the mAP over all channels and all threshold distances, and log it mean_ap = mean_ap_per_threshold.mean() / len(self.keypoint_channel_configuration) self.log(f"{mode}/meanAP", mean_ap) self.log(f"{mode}/meanAP/meanAP", mean_ap) + def training_epoch_end(self, outputs): + """ + Called on the end of a training epoch. + Used to compute and log the AP metrics. + """ + if self.is_ap_epoch(): + self.log_and_reset_mean_ap("train") + def validation_epoch_end(self, outputs): """ Called on the end of a validation epoch. @@ -396,18 +420,18 @@ def compute_and_log_metrics_for_channel( self, metrics: KeypointAPMetrics, channel: str, training_mode: str ) -> List[float]: """ - logs AP of predictions of single Channel for each threshold distance (as configured) for the categorization of the keypoints into a confusion matrix. - Also resets metric and returns resulting meanAP over all channels. + logs AP of predictions of single Channel for each threshold distance. + Also resets metric and returns resulting AP for all distances. """ - # compute ap's ap_metrics = metrics.compute() - print(f"{ap_metrics=}") + rounded_ap_metrics = {k: round(v, 3) for k, v in ap_metrics.items()} + print(f"{channel} : {rounded_ap_metrics}") for maximal_distance, ap in ap_metrics.items(): self.log(f"{training_mode}/{channel}_ap/d={float(maximal_distance):.1f}", ap) mean_ap = sum(ap_metrics.values()) / len(ap_metrics.values()) - self.log(f"{training_mode}/{channel}_ap/meanAP", mean_ap) # log top level for wandb hyperparam chart. + metrics.reset() return list(ap_metrics.values())