Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speed-up labeling suggestions look-up #709

Merged
merged 12 commits into from
Apr 25, 2022
8 changes: 5 additions & 3 deletions sleap/gui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
import sleap
from sleap.gui.dialogs.metrics import MetricsTableDialog
from sleap.skeleton import Skeleton
from sleap.instance import Instance
from sleap.instance import Instance, LabeledFrame
from sleap.io.dataset import Labels
from sleap.info.summary import StatisticSeries
from sleap.gui.commands import CommandContext, UpdateTopic
Expand Down Expand Up @@ -265,7 +265,7 @@ def dropEvent(self, event):
self.commands.showImportVideos(filenames=filenames)

@property
def labels(self):
def labels(self) -> Labels:
return self.state["labels"]

@labels.setter
Expand Down Expand Up @@ -1280,7 +1280,9 @@ def _has_topic(topic_list):
if suggestion_list:
labeled_count = 0
for suggestion in suggestion_list:
lf = self.labels.get((suggestion.video, suggestion.frame_idx))
lf = self.labels.get(
(suggestion.video, suggestion.frame_idx), use_cache=True
)
roomrys marked this conversation as resolved.
Show resolved Hide resolved
if lf is not None and lf.has_user_instances:
labeled_count += 1
prc = (labeled_count / len(suggestion_list)) * 100
Expand Down
2 changes: 1 addition & 1 deletion sleap/gui/dataviews.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,7 @@ def item_to_data(self, obj, item):
item_dict["frame"] = int(item.frame_idx) + 1 # start at frame 1 rather than 0

# show how many labeled instances are in this frame
lf = labels.get((item.video, item.frame_idx))
lf = labels.get((item.video, item.frame_idx), use_cache=True)
roomrys marked this conversation as resolved.
Show resolved Hide resolved
val = 0 if lf is None else len(lf.user_instances)
val = str(val) if val > 0 else ""
item_dict["labeled"] = val
Expand Down
201 changes: 140 additions & 61 deletions sleap/io/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,75 +629,147 @@ def __contains__(self, item) -> bool:
return self.find_first(item[0], item[1].tolist()) is not None
raise ValueError("Item is not an object type contained in labels.")

def __getitem__(self, key, *args) -> Union[LabeledFrame, List[LabeledFrame]]:
"""Return labeled frames matching key.
def __getitem__(
self,
key: Union[
int,
slice,
np.integer,
np.ndarray,
list,
range,
Video,
Tuple[Video, Union[np.integer, np.ndarray, int, list, range]],
],
*secondary_key: Union[
int,
slice,
np.integer,
np.ndarray,
list,
range,
],
) -> Union[LabeledFrame, List[LabeledFrame]]:
roomrys marked this conversation as resolved.
Show resolved Hide resolved
"""Return labeled frames matching key or return `None` if not found.

This makes `labels[...]` safe and will not raise an exception if the
item is not found.

Do not call __getitem__ directly, use get instead (get allows kwargs for logic).
If you happen to call __getitem__ directly, get will be called but without any
keyword arguments.

Args:
key: Indexing argument to match against. If `key` is a `Video` or tuple of
`(Video, frame_index)`, frames that match the criteria will be searched
for. If a scalar, list, range or array of integers are provided, the
labels with those linear indices will be returned.
secondary_key: Numerical indexing argument(s) which supplement `key`. Only
used when `key` is a `Video`.
"""
return self.get(key, *secondary_key)

def get(
self,
key: Union[
int,
slice,
np.integer,
np.ndarray,
list,
range,
Video,
Tuple[Video, Union[np.integer, np.ndarray, int, list, range]],
],
*secondary_key: Union[
int,
slice,
np.integer,
np.ndarray,
list,
range,
],
use_cache: bool = False,
raise_errors: bool = False,
) -> Union[LabeledFrame, List[LabeledFrame]]:
"""Return labeled frames matching key or return `None` if not found.

This is a safe version of `labels[...]` that will not raise an exception if the
item is not found.

Args:
key: Indexing argument to match against. If `key` is a `Video` or tuple of
`(Video, frame_index)`, frames that match the criteria will be searched
for. If a scalar, list, range or array of integers are provided, the
labels with those linear indices will be returned.
secondary_key: Numerical indexing argument(s) which supplement `key`. Only
used when `key` is of type `Video`.
use_cache: Boolean that determines whether Labels.find_first() should
instead instead call Labels.find() which uses the labels data cache. If
True, use the labels data cache, else loop through all labels to search.
raise_errors: Boolean that determines whether KeyErrors should be raised. If
True, raises KeyErrors, else catches KeyErrors and returns None instead
of raising KeyError.

Raises:
KeyError: If the specified key could not be found.

Returns:
A list with the matching `LabeledFrame`s, or a single `LabeledFrame` if a
scalar key was provided.
scalar key was provided, or `None` if not found.
"""
if len(args) > 0:
if type(key) != tuple:
key = (key,)
key = key + tuple(args)

if isinstance(key, int):
return self.labels.__getitem__(key)

elif isinstance(key, Video):
if key not in self.videos:
raise KeyError("Video not found in labels.")
return self.find(video=key)

elif isinstance(key, tuple) and len(key) == 2 and isinstance(key[0], Video):
if key[0] not in self.videos:
raise KeyError("Video not found in labels.")

if isinstance(key[1], int):
_hit = self.find_first(video=key[0], frame_idx=key[1])
if _hit is None:
raise KeyError(
f"No label found for specified video at frame {key[1]}."
try:
if len(secondary_key) > 0:
if type(key) != tuple:
key = (key,)
key = key + tuple(secondary_key)

# Do any conversions first.
if isinstance(key, slice):
start, stop, step = key.indices(len(self))
key = range(start, stop, step)
elif isinstance(key, (np.integer, np.ndarray)):
key = key.tolist()

if isinstance(key, int):
return self.labels.__getitem__(key)

elif isinstance(key, Video):
if key not in self.videos:
raise KeyError("Video not found in labels.")
return self.find(video=key)

elif isinstance(key, tuple) and len(key) == 2 and isinstance(key[0], Video):
if key[0] not in self.videos:
raise KeyError("Video not found in labels.")

# Do any conversions first.
if isinstance(key[1], (np.integer, np.ndarray)):
key = (key[0], key[1].tolist())

if isinstance(key[1], int):
_hit = self.find_first(
video=key[0], frame_idx=key[1], use_cache=use_cache
)
return _hit
elif isinstance(key[1], (np.integer, np.ndarray)):
return self.__getitem__((key[0], key[1].tolist()))
elif isinstance(key[1], (list, range)):
return self.find(video=key[0], frame_idx=key[1])
else:
raise KeyError("Invalid label indexing arguments.")

elif isinstance(key, slice):
start, stop, step = key.indices(len(self))
return self.__getitem__(range(start, stop, step))

elif isinstance(key, (list, range)):
return [self.__getitem__(i) for i in key]

elif isinstance(key, (np.integer, np.ndarray)):
return self.__getitem__(key.tolist())
if _hit is None:
raise KeyError(
f"No label found for specified video at frame {key[1]}."
)
return _hit
elif isinstance(key[1], (list, range)):
return self.find(video=key[0], frame_idx=key[1])
else:
raise KeyError("Invalid label indexing arguments.")

else:
raise KeyError("Invalid label indexing arguments.")
elif isinstance(key, (list, range)):
return [self.__getitem__(i) for i in key]

def get(self, *args) -> Union[LabeledFrame, List[LabeledFrame]]:
"""Get an item from the labels or return `None` if not found.
else:
raise KeyError("Invalid label indexing arguments.")

This is a safe version of `labels[...]` that will not raise an exception if the
item is not found.
"""
try:
return self.__getitem__(*args)
except KeyError:
except KeyError as e:
if raise_errors:
raise e
return None
roomrys marked this conversation as resolved.
Show resolved Hide resolved

def extract(self, inds, copy: bool = False) -> "Labels":
Expand Down Expand Up @@ -903,28 +975,35 @@ def frames(self, video: Video, from_frame_idx: int = -1, reverse=False):
yield self._cache._frame_idx_map[video][idx]

def find_first(
self, video: Video, frame_idx: Optional[int] = None
self, video: Video, frame_idx: Optional[int] = None, use_cache: bool = False
roomrys marked this conversation as resolved.
Show resolved Hide resolved
) -> Optional[LabeledFrame]:
"""Find the first occurrence of a matching labeled frame.

Matches on frames for the given video and/or frame index.

Args:
video: a `Video` instance that is associated with the
video: A `Video` instance that is associated with the
labeled frames
frame_idx: an integer specifying the frame index within
frame_idx: An integer specifying the frame index within
the video
use_cache: Boolean that determines whether Labels.find_first() should
instead instead call Labels.find() which uses the labels data cache. If
True, use the labels data cache, else loop through all labels to search.

Returns:
First `LabeledFrame` that match the criteria
or None if none were found.
"""
if video in self.videos:
for label in self.labels:
if label.video == video and (
frame_idx is None or (label.frame_idx == frame_idx)
):
return label
if use_cache:
label = self.find(video=video, frame_idx=frame_idx)
return None if len(label) == 0 else label[0]
else:
if video in self.videos:
for label in self.labels:
if label.video == video and (
frame_idx is None or (label.frame_idx == frame_idx)
):
return label

def find_last(
self, video: Video, frame_idx: Optional[int] = None
Expand Down
36 changes: 35 additions & 1 deletion tests/gui/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
from sleap.gui.commands import *


def test_app_workflow(qtbot, centered_pair_vid, small_robot_mp4_vid):
def test_app_workflow(
qtbot, centered_pair_vid, small_robot_mp4_vid, min_tracks_2node_labels: Labels
):
app = MainWindow()

# Add nodes
Expand Down Expand Up @@ -208,6 +210,38 @@ def test_app_workflow(qtbot, centered_pair_vid, small_robot_mp4_vid):
assert inst_31_2.track == track_a
assert inst_31_1.track == track_b

# Set up to test labeled frames data cache
app.labels = min_tracks_2node_labels
video = app.labels.video
num_samples = 5
frame_delta = video.num_frames // num_samples

# Add suggestions
app.labels.suggestions = VideoFrameSuggestions.suggest(
labels=app.labels,
params=dict(method="sample", per_video=num_samples, sampling_method="stride"),
)
assert len(app.labels.suggestions) == num_samples

# The on_data_update function uses labeled frames cache
app.on_data_update([UpdateTopic.suggestions])
assert len(app.suggestionsTable.model().items) == num_samples
assert f"{num_samples}/{num_samples}" in app.suggested_count_label.text()

# Check that frames returned by labeled frames cache are correct
prev_idx = -frame_delta
for l_suggestion, st_suggestion in list(
zip(app.labels.get_suggestions(), app.suggestionsTable.model().items)
):
assert l_suggestion == st_suggestion["SuggestionFrame"]
lf = app.labels.get(
(l_suggestion.video, l_suggestion.frame_idx), use_cache=True
)
assert type(lf) == LabeledFrame
assert lf.video == video
assert lf.frame_idx == prev_idx + frame_delta
prev_idx = l_suggestion.frame_idx


def test_app_new_window(qtbot):
app = QApplication.instance()
Expand Down
Loading