diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py index 558aa9309..83f994627 100644 --- a/sleap/nn/tracking.py +++ b/sleap/nn/tracking.py @@ -574,7 +574,7 @@ class Tracker(BaseTracker): max_tracking: bool = False # To enable maximum tracking. cleaner: Optional[Callable] = None # TODO: deprecate - target_instance_count: int = 0 + target_instance_count: int = 0 # TODO: deprecate pre_cull_function: Optional[Callable] = None post_connect_single_breaks: bool = False robust_best_instance: float = 1.0 @@ -824,8 +824,15 @@ def final_pass(self, frames: List[LabeledFrame]): # "tracking." # ) self.cleaner.run(frames) - elif self.target_instance_count and self.post_connect_single_breaks: + elif ( + self.target_instance_count or self.max_tracks + ) and self.post_connect_single_breaks: + if not self.target_instance_count: + # If target_instance_count is not set, use max_tracks instead + # target_instance_count not available in the GUI + self.target_instance_count = self.max_tracks connect_single_track_breaks(frames, self.target_instance_count) + print("Connecting single track breaks.") def get_name(self): tracker_name = self.candidate_maker.__class__.__name__ @@ -850,7 +857,7 @@ def make_tracker_by_name( of_max_levels: int = 3, save_shifted_instances: bool = False, # Pre-tracking options to cull instances - target_instance_count: int = 0, + target_instance_count: int = 0, # TODO: deprecate target_instance_count pre_cull_to_target: bool = False, pre_cull_iou_threshold: Optional[float] = None, # Post-tracking options to connect broken tracks @@ -921,6 +928,7 @@ def make_tracker_by_name( pre_cull_function = None if target_instance_count and pre_cull_to_target: + # Right now this is not accessible from the GUI def pre_cull_function(inst_list): cull_frame_instances( @@ -940,11 +948,17 @@ def pre_cull_function(inst_list): pre_cull_function=pre_cull_function, max_tracking=max_tracking, max_tracks=max_tracks, - target_instance_count=target_instance_count, + target_instance_count=target_instance_count, # TODO: deprecate target_instance_count post_connect_single_breaks=post_connect_single_breaks, ) - if target_instance_count and kf_init_frame_count: + # Kalman filter requires deprecated target_instance_count + if (max_tracks or target_instance_count) and kf_init_frame_count: + if not target_instance_count: + # If target_instance_count is not set, use max_tracks instead + # target_instance_count not available in the GUI + target_instance_count = max_tracks + kalman_obj = KalmanTracker.make_tracker( init_tracker=tracker_obj, init_frame_count=kf_init_frame_count, @@ -954,8 +968,10 @@ def pre_cull_function(inst_list): ) return kalman_obj - elif kf_init_frame_count and not target_instance_count: - raise ValueError("Kalman filter requires target instance count.") + elif kf_init_frame_count and not (max_tracks or target_instance_count): + raise ValueError( + "Kalman filter requires max tracks or target instance count." + ) else: return tracker_obj @@ -1369,6 +1385,10 @@ def cull_function(inst_list): if init_tracker.pre_cull_function is None: init_tracker.pre_cull_function = cull_function + print( + f"Using {init_tracker.get_name()} to track {init_frame_count} frames for Kalman filters." + ) + return cls( init_tracker=init_tracker, kalman_tracker=kalman_tracker, @@ -1386,6 +1406,7 @@ def track( untracked_instances: List[InstanceType], img: Optional[np.ndarray] = None, t: int = None, + **kwargs, ) -> List[InstanceType]: """Tracks individual frame, using Kalman filters if possible.""" @@ -1420,7 +1441,7 @@ def track( # Initialize the Kalman filters self.kalman_tracker.init_filters(self.init_set.instances) - # print(f"Kalman filters initialized (frame {t})") + print(f"Kalman filters initialized (frame {t})") # Clear the data used to init filters, so that if the filters # stop tracking and we need to re-init, we won't re-use the