Skip to content

Commit

Permalink
Add Option to Export CSV (#1438)
Browse files Browse the repository at this point in the history
* Add Option to Export CSV

* Add Test Functions

* Fomat Files

* Change FormatID
  • Loading branch information
gitttt-1234 authored Aug 10, 2023
1 parent 5ba6bc1 commit 47f8096
Show file tree
Hide file tree
Showing 8 changed files with 249 additions and 23 deletions.
14 changes: 14 additions & 0 deletions sleap/gui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,20 @@ def add_submenu_choices(menu, title, options, key):
lambda: self.commands.exportAnalysisFile(all_videos=True),
)

export_csv_menu = fileMenu.addMenu("Export Analysis CSV...")
add_menu_item(
export_csv_menu,
"export_csv_current",
"Current Video...",
self.commands.exportCSVFile,
)
add_menu_item(
export_csv_menu,
"export_csv_all",
"All Videos...",
lambda: self.commands.exportCSVFile(all_videos=True),
)

add_menu_item(fileMenu, "export_nwb", "Export NWB...", self.commands.exportNWB)

fileMenu.addSeparator()
Expand Down
44 changes: 34 additions & 10 deletions sleap/gui/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ class which inherits from `AppCommand` (or a more specialized class such as
import cv2
import attr
from qtpy import QtCore, QtWidgets, QtGui
from qtpy.QtWidgets import QMessageBox, QProgressDialog

from sleap.util import get_package_file
from sleap.skeleton import Node, Skeleton
Expand All @@ -51,6 +50,7 @@ class which inherits from `AppCommand` (or a more specialized class such as
from sleap.io.convert import default_analysis_filename
from sleap.io.dataset import Labels
from sleap.io.format.adaptor import Adaptor
from sleap.io.format.csv import CSVAdaptor
from sleap.io.format.ndx_pose import NDXPoseAdaptor
from sleap.gui.dialogs.delete import DeleteDialog
from sleap.gui.dialogs.importvideos import ImportVideos
Expand Down Expand Up @@ -331,7 +331,11 @@ def saveProjectAs(self):

def exportAnalysisFile(self, all_videos: bool = False):
"""Shows gui for exporting analysis h5 file."""
self.execute(ExportAnalysisFile, all_videos=all_videos)
self.execute(ExportAnalysisFile, all_videos=all_videos, csv=False)

def exportCSVFile(self, all_videos: bool = False):
"""Shows gui for exporting analysis csv file."""
self.execute(ExportAnalysisFile, all_videos=all_videos, csv=True)

def exportNWB(self):
"""Show gui for exporting nwb file."""
Expand Down Expand Up @@ -1130,13 +1134,20 @@ class ExportAnalysisFile(AppCommand):
}
export_filter = ";;".join(export_formats.keys())

export_formats_csv = {
"CSV (*.csv)": "csv",
}
export_filter_csv = ";;".join(export_formats_csv.keys())

@classmethod
def do_action(cls, context: CommandContext, params: dict):
from sleap.io.format.sleap_analysis import SleapAnalysisAdaptor
from sleap.io.format.nix import NixAdaptor

for output_path, video in params["analysis_videos"]:
if Path(output_path).suffix[1:] == "nix":
if params["csv"]:
adaptor = CSVAdaptor
elif Path(output_path).suffix[1:] == "nix":
adaptor = NixAdaptor
else:
adaptor = SleapAnalysisAdaptor
Expand All @@ -1149,18 +1160,24 @@ def do_action(cls, context: CommandContext, params: dict):

@staticmethod
def ask(context: CommandContext, params: dict) -> bool:
def ask_for_filename(default_name: str) -> str:
def ask_for_filename(default_name: str, csv: bool) -> str:
"""Allow user to specify the filename"""
filter = (
ExportAnalysisFile.export_filter_csv
if csv
else ExportAnalysisFile.export_filter
)
filename, selected_filter = FileDialog.save(
context.app,
caption="Export Analysis File...",
dir=default_name,
filter=ExportAnalysisFile.export_filter,
filter=filter,
)
return filename

# Ensure labels has labeled frames
labels = context.labels
is_csv = params["csv"]
if len(labels.labeled_frames) == 0:
raise ValueError("No labeled frames in project. Nothing to export.")

Expand All @@ -1178,7 +1195,7 @@ def ask_for_filename(default_name: str) -> str:
# Specify (how to get) the output filename
default_name = context.state["filename"] or "labels"
fn = PurePath(default_name)
file_extension = "h5"
file_extension = "csv" if is_csv else "h5"
if len(videos) == 1:
# Allow user to specify the filename
use_default = False
Expand All @@ -1191,18 +1208,23 @@ def ask_for_filename(default_name: str) -> str:
caption="Select Folder to Export Analysis Files...",
dir=str(fn.parent),
)
if len(ExportAnalysisFile.export_formats) > 1:
export_format = (
ExportAnalysisFile.export_formats_csv
if is_csv
else ExportAnalysisFile.export_formats
)
if len(export_format) > 1:
item, ok = QtWidgets.QInputDialog.getItem(
context.app,
"Select export format",
"Available export formats",
list(ExportAnalysisFile.export_formats.keys()),
list(export_format.keys()),
0,
False,
)
if not ok:
return False
file_extension = ExportAnalysisFile.export_formats[item]
file_extension = export_format[item]
if len(dirname) == 0:
return False

Expand All @@ -1219,7 +1241,9 @@ def ask_for_filename(default_name: str) -> str:
format_suffix=file_extension,
)

filename = default_name if use_default else ask_for_filename(default_name)
filename = (
default_name if use_default else ask_for_filename(default_name, is_csv)
)
# Check that filename is valid and create list of video / output paths
if len(filename) != 0:
analysis_videos.append(video)
Expand Down
74 changes: 72 additions & 2 deletions sleap/info/write_tracking_h5.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Generate an HDF5 file with track occupancy and point location data.
"""Generate an HDF5 or CSV file with track occupancy and point location data.
Ignores tracks that are entirely empty. By default will also ignore
empty frames from the beginning and end of video, although
Expand Down Expand Up @@ -29,6 +29,7 @@
import json
import h5py as h5
import numpy as np
import pandas as pd

from typing import Any, Dict, List, Tuple, Union

Expand Down Expand Up @@ -286,12 +287,77 @@ def write_occupancy_file(
print(f"Saved as {output_path}")


def write_csv_file(output_path, data_dict):

"""Write CSV file with data from given dictionary.
Args:
output_path: Path of HDF5 file.
data_dict: Dictionary with data to save. Keys are dataset names,
values are the data.
Returns:
None
"""

if data_dict["tracks"].shape[-1] == 0:
print(f"No tracks to export in {data_dict['video_path']}. Skipping the export")
return

data_dict["node_names"] = [s.decode() for s in data_dict["node_names"]]
data_dict["track_names"] = [s.decode() for s in data_dict["track_names"]]
data_dict["track_occupancy"] = np.transpose(data_dict["track_occupancy"]).astype(
bool
)

# Find frames with at least one animal tracked.
valid_frame_idxs = np.argwhere(data_dict["track_occupancy"].any(axis=1)).flatten()

tracks = []
for frame_idx in valid_frame_idxs:
frame_tracks = data_dict["tracks"][frame_idx]

for i in range(frame_tracks.shape[-1]):
pts = frame_tracks[..., i]
conf_scores = data_dict["point_scores"][frame_idx][..., i]

if np.isnan(pts).all():
# Skip if animal wasn't detected in the current frame.
continue
if data_dict["track_names"]:
track = data_dict["track_names"][i]
else:
track = None

instance_score = data_dict["instance_scores"][frame_idx][i]

detection = {
"track": track,
"frame_idx": frame_idx,
"instance.score": instance_score,
}

# Coordinates for each body part.
for node_name, score, (x, y) in zip(
data_dict["node_names"], conf_scores, pts
):
detection[f"{node_name}.x"] = x
detection[f"{node_name}.y"] = y
detection[f"{node_name}.score"] = score

tracks.append(detection)

tracks = pd.DataFrame(tracks)
tracks.to_csv(output_path, index=False)


def main(
labels: Labels,
output_path: str,
labels_path: str = None,
all_frames: bool = True,
video: Video = None,
csv: bool = False,
):
"""Writes HDF5 file with matrices of track occupancy and coordinates.
Expand All @@ -306,6 +372,7 @@ def main(
video: The :py:class:`Video` from which to get data. If no `video` is specified,
then the first video in `source_object` videos list will be used. If there
are no labeled frames in the `video`, then no output file will be written.
csv: Bool to save the analysis as a csv file if set to True
Returns:
None
Expand Down Expand Up @@ -367,7 +434,10 @@ def main(
provenance=json.dumps(labels.provenance), # dict cannot be written to hdf5.
)

write_occupancy_file(output_path, data_dict, transpose=True)
if csv:
write_csv_file(output_path, data_dict)
else:
write_occupancy_file(output_path, data_dict, transpose=True)


if __name__ == "__main__":
Expand Down
70 changes: 70 additions & 0 deletions sleap/io/format/csv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
"""Adaptor for writing SLEAP analysis as csv."""

from sleap.io import format

from sleap import Labels, Video
from typing import Optional, Callable, List, Text, Union


class CSVAdaptor(format.adaptor.Adaptor):
FORMAT_ID = 1.0

# 1.0 initial implementation

@property
def handles(self):
return format.adaptor.SleapObjectType.labels

@property
def default_ext(self):
return "csv"

@property
def all_exts(self):
return ["csv", "xlsx"]

@property
def name(self):
return "CSV"

def can_read_file(self, file: format.filehandle.FileHandle):
return False

def can_write_filename(self, filename: str):
return self.does_match_ext(filename)

def does_read(self) -> bool:
return False

def does_write(self) -> bool:
return True

@classmethod
def write(
cls,
filename: str,
source_object: Labels,
source_path: str = None,
video: Video = None,
):
"""Writes csv file for :py:class:`Labels` `source_object`.
Args:
filename: The filename for the output file.
source_object: The :py:class:`Labels` from which to get data from.
source_path: Path for the labels object
video: The :py:class:`Video` from which toget data from. If no `video` is
specified, then the first video in `source_object` videos list will be
used. If there are no :py:class:`Labeled Frame`s in the `video`, then no
analysis file will be written.
"""
from sleap.info.write_tracking_h5 import main as write_analysis

write_analysis(
labels=source_object,
output_path=filename,
labels_path=source_path,
all_frames=True,
video=video,
csv=True,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
track,frame_idx,instance.score,A.x,A.y,A.score,B.x,B.y,B.score
,0,nan,205.9300539013689,187.88964024221963,,278.63521449272383,203.3658657346604,
8 changes: 8 additions & 0 deletions tests/fixtures/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
TEST_HDF5_PREDICTIONS = "tests/data/hdf5_format_v1/centered_pair_predictions.h5"
TEST_SLP_PREDICTIONS = "tests/data/hdf5_format_v1/centered_pair_predictions.slp"
TEST_MIN_DANCE_LABELS = "tests/data/slp_hdf5/dance.mp4.labels.slp"
TEST_CSV_PREDICTIONS = (
"tests/data/csv_format/minimal_instance.000_centered_pair_low_quality.analysis.csv"
)


@pytest.fixture
Expand Down Expand Up @@ -247,6 +250,11 @@ def centered_pair_predictions_hdf5_path():
return TEST_HDF5_PREDICTIONS


@pytest.fixture
def minimal_instance_predictions_csv_path():
return TEST_CSV_PREDICTIONS


@pytest.fixture
def centered_pair_predictions_slp_path():
return TEST_SLP_PREDICTIONS
Expand Down
Loading

0 comments on commit 47f8096

Please sign in to comment.