Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Option to Export CSV #1438

Merged
merged 4 commits into from
Aug 10, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions sleap/gui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,20 @@ def add_submenu_choices(menu, title, options, key):
lambda: self.commands.exportAnalysisFile(all_videos=True),
)

export_analysis_menu = fileMenu.addMenu("Export Analysis CSV...")
add_menu_item(
export_analysis_menu,
"export_csv_current",
"Current Video...",
self.commands.exportCSV,
)
add_menu_item(
export_analysis_menu,
"export_csv_video",
gitttt-1234 marked this conversation as resolved.
Show resolved Hide resolved
"All Videos...",
lambda: self.commands.exportCSV(all_videos=True),
)

add_menu_item(fileMenu, "export_nwb", "Export NWB...", self.commands.exportNWB)

fileMenu.addSeparator()
Expand Down
107 changes: 107 additions & 0 deletions sleap/gui/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class which inherits from `AppCommand` (or a more specialized class such as
from sleap.io.dataset import Labels
from sleap.io.format.adaptor import Adaptor
from sleap.io.format.ndx_pose import NDXPoseAdaptor
from sleap.io.format.csv import CSVAdaptor
gitttt-1234 marked this conversation as resolved.
Show resolved Hide resolved
from sleap.gui.dialogs.delete import DeleteDialog
from sleap.gui.dialogs.importvideos import ImportVideos
from sleap.gui.dialogs.filedialog import FileDialog
Expand Down Expand Up @@ -337,6 +338,10 @@ def exportNWB(self):
"""Show gui for exporting nwb file."""
self.execute(SaveProjectAs, adaptor=NDXPoseAdaptor())

def exportCSV(self, all_videos:bool = False):
"""Show gui for exporting csv file."""
self.execute(ExportCSVFile, all_videos=all_videos)

def exportLabeledClip(self):
"""Shows gui for exporting clip with visual annotations."""
self.execute(ExportLabeledClip)
Expand Down Expand Up @@ -1122,6 +1127,108 @@ def ask(context: CommandContext, params: dict) -> bool:
params["filename"] = filename
return True

class ExportCSVFile(AppCommand):
gitttt-1234 marked this conversation as resolved.
Show resolved Hide resolved
export_formats = {
"CSV (*.csv)": "csv",
"Excel Worksheet (*.xlsx)": "xlsx"
}
export_filter = ";;".join(export_formats.keys())

@classmethod
def do_action(cls, context: CommandContext, params: dict):
for output_path, video in params["analysis_videos"]:
adaptor = CSVAdaptor
adaptor.write(
filename=output_path,
source_object=context.labels,
source_path=context.state["filename"],
video=video,
)

@staticmethod
def ask(context: CommandContext, params: dict) -> bool:
def ask_for_filename(default_name: str) -> str:
"""Allow user to specify the filename"""
filename, selected_filter = FileDialog.save(
context.app,
caption="Export CSV File...",
dir=default_name,
filter=ExportCSVFile.export_filter,
)
return filename

# Ensure labels has labeled frames
labels = context.labels
if len(labels.labeled_frames) == 0:
raise ValueError("No labeled frames in project. Nothing to export.")

# Get a subset of videos
if params["all_videos"]:
all_videos = context.labels.videos
else:
all_videos = [context.state["video"] or context.labels.videos[0]]

# Only use videos with labeled frames
videos = [video for video in all_videos if len(labels.get(video)) != 0]
if len(videos) == 0:
raise ValueError("No labeled frames in video(s). Nothing to export.")

# Specify (how to get) the output filename
default_name = context.state["filename"] or "labels"
fn = PurePath(default_name)
file_extension = "csv"
if len(videos) == 1:
# Allow user to specify the filename
use_default = False
dirname = str(fn.parent)
else:
# Allow user to specify directory, but use default filenames
use_default = True
dirname = FileDialog.openDir(
context.app,
caption="Select Folder to Export csv Files...",
dir=str(fn.parent),
)
if len(ExportAnalysisFile.export_formats) > 1:
gitttt-1234 marked this conversation as resolved.
Show resolved Hide resolved
item, ok = QtWidgets.QInputDialog.getItem(
context.app,
"Select export format",
"Available export formats",
list(ExportCSVFile.export_formats.keys()),
0,
False,
)
if not ok:
return False
file_extension = ExportCSVFile.export_formats[item]
if len(dirname) == 0:
return False

# Create list of video / output paths
output_paths = []
analysis_videos = []
for video in videos:
# Create the filename
default_name = default_analysis_filename(
labels=labels,
video=video,
output_path=dirname,
output_prefix=str(fn.stem),
format_suffix=file_extension,
)

filename = default_name if use_default else ask_for_filename(default_name)
# Check that filename is valid and create list of video / output paths
if len(filename) != 0:
analysis_videos.append(video)
output_paths.append(filename)

# Chack that output paths are valid
if len(output_paths) == 0:
return False

params["analysis_videos"] = zip(output_paths, videos)
return True

class ExportAnalysisFile(AppCommand):
export_formats = {
Expand Down
61 changes: 60 additions & 1 deletion sleap/info/write_tracking_h5.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import json
import h5py as h5
import numpy as np
import pandas as pd

from typing import Any, Dict, List, Tuple, Union

Expand Down Expand Up @@ -286,12 +287,66 @@ def write_occupancy_file(
print(f"Saved as {output_path}")


def write_csv_file(output_path, data_dict):

"""Write CSV file with data from given dictionary.

Args:
output_path: Path of HDF5 file.
data_dict: Dictionary with data to save. Keys are dataset names,
values are the data.

Returns:
None
"""

data_dict["node_names"] = [s.decode() for s in data_dict["node_names"]]
data_dict["track_names"] = [s.decode() for s in data_dict["track_names"]]

data_dict["track_occupancy"] = np.array(data_dict["track_occupancy"]).astype(bool)

# find frames with at least one animal tracked
gitttt-1234 marked this conversation as resolved.
Show resolved Hide resolved
valid_frame_idxs = np.argwhere(data_dict["track_occupancy"].any(axis=1)).flatten()

tracks = []
for frame_idx in valid_frame_idxs:
# Tracking data for the current frame.
frame_tracks = data_dict["tracks"][frame_idx]

# Loop over the animals in the current frame.
for i in range(frame_tracks.shape[-1]):
pts = frame_tracks[..., i]

if np.isnan(pts).all():
# Skip this animal if all of its points are missing (i.e., it wasn't
# detected in the current frame).
continue

# Initialize row with some metadata.
if data_dict["track_names"]:
track = data_dict["track_names"][i]
else:
track = None
detection = {"track": track, "frame_idx": frame_idx}

# Coordinates for each body part.
for node_name, (x, y) in zip(data_dict["node_names"], pts):
detection[f"{node_name}.x"] = x
detection[f"{node_name}.y"] = y

tracks.append(detection)

tracks = pd.DataFrame(tracks)
tracks.to_csv(output_path, index=False)


def main(
labels: Labels,
output_path: str,
labels_path: str = None,
all_frames: bool = True,
video: Video = None,
csv: bool = False,
):
"""Writes HDF5 file with matrices of track occupancy and coordinates.

Expand All @@ -306,6 +361,7 @@ def main(
video: The :py:class:`Video` from which to get data. If no `video` is specified,
then the first video in `source_object` videos list will be used. If there
are no labeled frames in the `video`, then no output file will be written.
csv: Bool to save the analysis as a csv file if set to True

Returns:
None
Expand Down Expand Up @@ -367,7 +423,10 @@ def main(
provenance=json.dumps(labels.provenance), # dict cannot be written to hdf5.
)

write_occupancy_file(output_path, data_dict, transpose=True)
if csv:
write_csv_file(output_path, data_dict)
else:
write_occupancy_file(output_path, data_dict, transpose=True)


if __name__ == "__main__":
Expand Down
75 changes: 75 additions & 0 deletions sleap/io/format/csv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""
Adaptor for writing SLEAP analysis as csv.

"""
gitttt-1234 marked this conversation as resolved.
Show resolved Hide resolved

from sleap.io import format

from sleap import Labels, Video
from typing import Optional, Callable, List, Text, Union


class CSVAdaptor(format.adaptor.Adaptor):
FORMAT_ID = 1.2

# 1.0 points with gridline coordinates, top left corner at (0, 0)
# 1.1 points with midpixel coordinates, top left corner at (-0.5, -0.5)
# 1.2 adds track score to read and write functions
gitttt-1234 marked this conversation as resolved.
Show resolved Hide resolved

@property
def handles(self):
return format.adaptor.SleapObjectType.labels

@property
def default_ext(self):
return "csv"

@property
def all_exts(self):
return ["csv", "xlsx"]

@property
def name(self):
return "CSV"

def can_read_file(self, file: format.filehandle.FileHandle):
return False

def can_write_filename(self, filename: str):
return self.does_match_ext(filename)

def does_read(self) -> bool:
return False

def does_write(self) -> bool:
return True

@classmethod
def write(
cls,
filename: str,
source_object: Labels,
source_path: str = None,
video: Video = None,
):
"""Writes csv file for :py:class:`Labels` `source_object`.

Args:
filename: The filename for the output file.
source_object: The :py:class:`Labels` from which to get data from.
source_path: Path for the labels object
video: The :py:class:`Video` from which toget data from. If no `video` is
specified, then the first video in `source_object` videos list will be
used. If there are no :py:class:`Labeled Frame`s in the `video`, then no
analysis file will be written.
"""
from sleap.info.write_tracking_h5 import main as write_analysis

write_analysis(
labels=source_object,
output_path=filename,
labels_path=source_path,
all_frames=True,
video=video,
csv=True
)
16 changes: 16 additions & 0 deletions tests/io/test_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from pathlib import Path, PurePath

import numpy as np
import pandas as pd
from numpy.testing import assert_array_equal
import pytest
import nixio
Expand Down Expand Up @@ -125,6 +126,21 @@ def test_hdf5_v1_filehandle(centered_pair_predictions_hdf5_path):
== "tests/data/json_format_v1/centered_pair_low_quality.mp4"
)

def test_csv(tmpdir, centered_pair_predictions):
from sleap.info.write_tracking_h5 import main as write_analysis

filename = os.path.join(tmpdir, "analysis.csv")
gitttt-1234 marked this conversation as resolved.
Show resolved Hide resolved
video = centered_pair_predictions.videos[0]

write_analysis(centered_pair_predictions, output_path=filename, all_frames=True, csv=True)

labels = pd.read_csv(
filename, header=True
)

# TODO: assert

gitttt-1234 marked this conversation as resolved.
Show resolved Hide resolved


def test_analysis_hdf5(tmpdir, centered_pair_predictions):
from sleap.info.write_tracking_h5 import main as write_analysis
Expand Down
Loading