Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add options to set background color when exporting video #1328

Merged
merged 12 commits into from
Sep 18, 2023
8 changes: 8 additions & 0 deletions sleap/config/labeled_clip_form.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,14 @@ main:
label: Use GUI Visual Settings (colors, line widths)
type: bool
default: true
- name: show_og_video
label: Show Original Video in Background
type: bool
default: true
- name: background_color
label: Background Color
type: list
options: Black,White,Grey
shrivaths16 marked this conversation as resolved.
Show resolved Hide resolved
- name: open_when_done
label: Open When Done Saving
type: bool
Expand Down
4 changes: 4 additions & 0 deletions sleap/gui/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -1271,6 +1271,8 @@ def do_action(context: CommandContext, params: dict):
frames=list(params["frames"]),
fps=params["fps"],
color_manager=params["color_manager"],
show_og_video=params["show_og_video"],
background_color=params["background_color"],
show_edges=params["show edges"],
edge_is_wedge=params["edge_is_wedge"],
marker_size=params["marker size"],
Expand Down Expand Up @@ -1330,6 +1332,8 @@ def ask(context: CommandContext, params: dict) -> bool:
params["fps"] = export_options["fps"]
params["scale"] = export_options["scale"]
params["open_when_done"] = export_options["open_when_done"]
params["show_og_video"] = export_options["show_og_video"]
params["background_color"] = export_options["background_color"]

params["crop"] = None

Expand Down
23 changes: 17 additions & 6 deletions sleap/io/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -1091,17 +1091,26 @@ def is_missing(self) -> bool:
else:
return not os.path.exists(self.backend.filename)

def get_frame(self, idx: int) -> np.ndarray:
def get_frame(self, idx: int, og_video=True, background_color="Black") -> np.ndarray:
"""
Return a single frame of video from the underlying video data.

Args:
idx: The index of the video frame
og_video: whether to include og_video in the background
background_color: if not showing the original video, setting background color

Returns:
The video frame with shape (height, width, channels)
"""
return self.backend.get_frame(idx)
if og_video:
return self.backend.get_frame(idx)
if background_color == "Black":
return self.backend.get_frame(idx) * 0
if background_color == "White":
return self.backend.get_frame(idx) * 0 + 255
if background_color == "Grey":
return self.backend.get_frame(idx) * 0 + 127

def get_frames(self, idxs: Union[int, Iterable[int]]) -> np.ndarray:
"""Return a collection of video frames from the underlying video data.
Expand All @@ -1116,10 +1125,12 @@ def get_frames(self, idxs: Union[int, Iterable[int]]) -> np.ndarray:
idxs = [idxs]
return np.stack([self.get_frame(idx) for idx in idxs], axis=0)

def get_frames_safely(self, idxs: Iterable[int]) -> Tuple[List[int], np.ndarray]:
def get_frames_safely(self, idxs: Iterable[int], og_video=True, background_color="Black") -> Tuple[List[int], np.ndarray]:
"""Return list of frame indices and frames which were successfully loaded.

idxs: An iterable object that contains the indices of frames.
Args:
idxs: An iterable object that contains the indices of frames.
og_video: whether to show the original video in the background
background_color: if not showing the original video, setting background color

Returns: A tuple of (frame indices, frames), where
* frame indices is a subset of the specified idxs, and
Expand All @@ -1131,7 +1142,7 @@ def get_frames_safely(self, idxs: Iterable[int]) -> Tuple[List[int], np.ndarray]

for idx in idxs:
try:
frame = self.get_frame(idx)
frame = self.get_frame(idx, og_video=og_video, background_color=background_color)
except Exception as e:
print(e)
# ignore frames which we couldn't load
Expand Down
14 changes: 11 additions & 3 deletions sleap/io/visuals.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
_sentinel = object()


def reader(out_q: Queue, video: Video, frames: List[int], scale: float = 1.0):
def reader(out_q: Queue, video: Video, frames: List[int], scale: float = 1.0, og_video=True, background_color="Black"):
"""Read frame images from video and send them into queue.

Args:
Expand All @@ -36,6 +36,8 @@ def reader(out_q: Queue, video: Video, frames: List[int], scale: float = 1.0):
video: The `Video` object to read.
frames: Full list frame indexes we want to read.
scale: Output scale for frame images.
og_video: Whether to include the original video in the background
background_color: if not include og_video, the background color set.

Returns:
None.
Expand All @@ -62,7 +64,9 @@ def reader(out_q: Queue, video: Video, frames: List[int], scale: float = 1.0):

# Safely load frames from video, skipping frames we can't load
loaded_chunk_idxs, video_frame_images = video.get_frames_safely(
frames_idx_chunk
frames_idx_chunk,
og_video=og_video,
background_color=background_color
)
shrivaths16 marked this conversation as resolved.
Show resolved Hide resolved

if not loaded_chunk_idxs:
Expand Down Expand Up @@ -497,6 +501,8 @@ def save_labeled_video(
fps: int = 15,
scale: float = 1.0,
crop_size_xy: Optional[Tuple[int, int]] = None,
show_og_video: bool = True,
background_color: str = "Black",
show_edges: bool = True,
edge_is_wedge: bool = False,
marker_size: int = 4,
Expand All @@ -515,6 +521,8 @@ def save_labeled_video(
fps: Frames per second for output video.
scale: scale of image (so we can scale point locations to match)
crop_size_xy: size of crop around instances, or None for full images
show_og_video: whether to show the original video in the background
background_color: if not showing the original video, setting background color
show_edges: whether to draw lines between nodes
edge_is_wedge: whether to draw edges as wedges (draw as line if False)
marker_size: Size of marker in pixels before scaling by `scale`
Expand All @@ -537,7 +545,7 @@ def save_labeled_video(
q2 = Queue(maxsize=10)
progress_queue = Queue()

thread_read = Thread(target=reader, args=(q1, video, frames, scale))
thread_read = Thread(target=reader, args=(q1, video, frames, scale, show_og_video, background_color))
thread_mark = VideoMarkerThread(
in_q=q1,
out_q=q2,
Expand Down