Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow max tracking args for Kalman filter #1986

Open
wants to merge 8 commits into
base: develop
Choose a base branch
from
37 changes: 29 additions & 8 deletions sleap/nn/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,7 +574,7 @@ class Tracker(BaseTracker):
max_tracking: bool = False # To enable maximum tracking.

cleaner: Optional[Callable] = None # TODO: deprecate
target_instance_count: int = 0
target_instance_count: int = 0 # TODO: deprecate
pre_cull_function: Optional[Callable] = None
post_connect_single_breaks: bool = False
robust_best_instance: float = 1.0
Expand Down Expand Up @@ -824,8 +824,15 @@ def final_pass(self, frames: List[LabeledFrame]):
# "tracking."
# )
self.cleaner.run(frames)
elif self.target_instance_count and self.post_connect_single_breaks:
elif (
self.target_instance_count or self.max_tracks
) and self.post_connect_single_breaks:
if not self.target_instance_count:
# If target_instance_count is not set, use max_tracks instead
# target_instance_count not available in the GUI
self.target_instance_count = self.max_tracks
connect_single_track_breaks(frames, self.target_instance_count)
print("Connecting single track breaks.")
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Use logging instead of print statements for better practice.

The print statement for status messages can be replaced with the logging module to provide configurable logging levels and outputs.

Apply this diff:

+import logging
...
+logger = logging.getLogger(__name__)
...
-print("Connecting single track breaks.")
+logger.info("Connecting single track breaks.")

Committable suggestion was skipped due to low confidence.


def get_name(self):
tracker_name = self.candidate_maker.__class__.__name__
Expand All @@ -850,7 +857,7 @@ def make_tracker_by_name(
of_max_levels: int = 3,
save_shifted_instances: bool = False,
# Pre-tracking options to cull instances
target_instance_count: int = 0,
target_instance_count: int = 0, # TODO: deprecate target_instance_count
pre_cull_to_target: bool = False,
pre_cull_iou_threshold: Optional[float] = None,
# Post-tracking options to connect broken tracks
Expand Down Expand Up @@ -921,6 +928,7 @@ def make_tracker_by_name(

pre_cull_function = None
if target_instance_count and pre_cull_to_target:
# Right now this is not accessible from the GUI

def pre_cull_function(inst_list):
cull_frame_instances(
Expand All @@ -940,11 +948,17 @@ def pre_cull_function(inst_list):
pre_cull_function=pre_cull_function,
max_tracking=max_tracking,
max_tracks=max_tracks,
target_instance_count=target_instance_count,
target_instance_count=target_instance_count, # TODO: deprecate target_instance_count
post_connect_single_breaks=post_connect_single_breaks,
)

if target_instance_count and kf_init_frame_count:
# Kalman filter requires deprecated target_instance_count
if (max_tracks or target_instance_count) and kf_init_frame_count:
if not target_instance_count:
# If target_instance_count is not set, use max_tracks instead
# target_instance_count not available in the GUI
target_instance_count = max_tracks
Comment on lines +955 to +960
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Update Kalman filter implementation to use max_tracks instead of deprecated target_instance_count.

The Kalman filter initialization depends on target_instance_count, which is marked for deprecation. Modify the Kalman filter code to use max_tracks exclusively, ensuring future maintainability and consistency.


kalman_obj = KalmanTracker.make_tracker(
init_tracker=tracker_obj,
init_frame_count=kf_init_frame_count,
Expand All @@ -954,8 +968,10 @@ def pre_cull_function(inst_list):
)

return kalman_obj
elif kf_init_frame_count and not target_instance_count:
raise ValueError("Kalman filter requires target instance count.")
elif kf_init_frame_count and not (max_tracks or target_instance_count):
raise ValueError(
"Kalman filter requires max tracks or target instance count."
)
else:
return tracker_obj

Expand Down Expand Up @@ -1369,6 +1385,10 @@ def cull_function(inst_list):
if init_tracker.pre_cull_function is None:
init_tracker.pre_cull_function = cull_function

print(
f"Using {init_tracker.get_name()} to track {init_frame_count} frames for Kalman filters."
)

return cls(
init_tracker=init_tracker,
kalman_tracker=kalman_tracker,
Expand All @@ -1386,6 +1406,7 @@ def track(
untracked_instances: List[InstanceType],
img: Optional[np.ndarray] = None,
t: int = None,
**kwargs,
) -> List[InstanceType]:
"""Tracks individual frame, using Kalman filters if possible."""

Expand Down Expand Up @@ -1420,7 +1441,7 @@ def track(
# Initialize the Kalman filters
self.kalman_tracker.init_filters(self.init_set.instances)

# print(f"Kalman filters initialized (frame {t})")
print(f"Kalman filters initialized (frame {t})")
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Use logging instead of print statements for initialization messages.

Replace the print statement with the logging module to improve flexibility and control over log outputs.

Apply this diff:

+import logging
...
+logger = logging.getLogger(__name__)
...
-print(f"Kalman filters initialized (frame {t})")
+logger.info(f"Kalman filters initialized (frame {t})")

Committable suggestion was skipped due to low confidence.


# Clear the data used to init filters, so that if the filters
# stop tracking and we need to re-init, we won't re-use the
Expand Down
Loading