Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Qt widget for loading pose datasets as napari Points layers #253

Open
wants to merge 24 commits into
base: napari-dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
26e99df
initialise napari plugin development
niksirbi Jun 12, 2024
edfe95e
initialise napari plugin development
niksirbi Jun 12, 2024
664ea08
Added loader widget for poses
niksirbi Jul 30, 2024
274ff59
update widget tests
niksirbi Jul 30, 2024
02b3cce
simplify dependency on brainglobe-utils
niksirbi Sep 2, 2024
d8ba49e
consistent monospace formatting for movement in public docstrings
niksirbi Sep 2, 2024
3fd8f9a
get rid of code that's only relevant for displaying Tracks
niksirbi Sep 2, 2024
8cb0c1e
enable visibility of napari layer tooltips
niksirbi Sep 2, 2024
2b842cc
renamed widget to PosesLoader
niksirbi Sep 2, 2024
38cdf5d
make cmap optional in set_color_by method
niksirbi Sep 3, 2024
366c76c
wrote unit tests for napari convert module
niksirbi Sep 3, 2024
33cbdc1
wrote unit-tests for the layer styles module
niksirbi Sep 12, 2024
719263c
linkcheck ignore zenodo redirects
niksirbi Sep 12, 2024
52dfd00
move _sample_colormap out of PointsStyle class
niksirbi Sep 13, 2024
06fd5fb
small refactoring in the loader widget
niksirbi Sep 13, 2024
94addfc
Expand tests for loader widget
niksirbi Sep 13, 2024
72fe058
added comments and docstrings to napari plugin tests
niksirbi Sep 16, 2024
6d6fe71
refactored all napari tests into separate unit test folder
niksirbi Sep 16, 2024
d9fd240
added napari-video to dependencies
niksirbi Sep 16, 2024
9a487b1
replaced deprecated edge_width with border_width
niksirbi Sep 16, 2024
a36f742
got rid of widget pytest fixtures
niksirbi Sep 16, 2024
fc7564f
remove duplicate word from docstring
niksirbi Sep 16, 2024
e15a613
remove napari-video dependency
niksirbi Oct 4, 2024
254c5f9
include napari extras in docs requirements
niksirbi Oct 10, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
-e .
-e .[napari]
linkify-it-py
myst-parser
nbsphinx
Expand Down
1 change: 1 addition & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@
"https://pubs.acs.org/doi/*", # Checking dois is forbidden here
]


myst_url_schemes = {
"http": None,
"https": None,
Expand Down
32 changes: 0 additions & 32 deletions movement/napari/_loader_widget.py

This file was deleted.

152 changes: 152 additions & 0 deletions movement/napari/_loader_widgets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
"""Widgets for loading movement datasets from file."""

import logging
from pathlib import Path

from napari.settings import get_settings
from napari.utils.notifications import show_warning
from napari.viewer import Viewer
from qtpy.QtWidgets import (
QComboBox,
QFileDialog,
QFormLayout,
QHBoxLayout,
QLineEdit,
QPushButton,
QSpinBox,
QWidget,
)

from movement.io import load_poses
from movement.napari.convert import poses_to_napari_tracks
from movement.napari.layer_styles import PointsStyle

logger = logging.getLogger(__name__)

# Allowed poses file suffixes for each supported source software
SUPPORTED_POSES_FILES = {
"DeepLabCut": ["*.h5", "*.csv"],
"LightningPose": ["*.csv"],
"SLEAP": ["*.h5", "*.slp"],
}


class PosesLoader(QWidget):
"""Widget for loading movement poses datasets from file."""

def __init__(self, napari_viewer: Viewer, parent=None):
"""Initialize the loader widget."""
super().__init__(parent=parent)
self.viewer = napari_viewer
self.setLayout(QFormLayout())
# Create widgets
self._create_source_software_widget()
self._create_fps_widget()
self._create_file_path_widget()
self._create_load_button()
# Enable layer tooltips from napari settings
self._enable_layer_tooltips()

def _create_source_software_widget(self):
"""Create a combo box for selecting the source software."""
self.source_software_combo = QComboBox()
self.source_software_combo.addItems(SUPPORTED_POSES_FILES.keys())
self.layout().addRow("source software:", self.source_software_combo)

def _create_fps_widget(self):
"""Create a spinbox for selecting the frames per second (fps)."""
self.fps_spinbox = QSpinBox()
self.fps_spinbox.setMinimum(1)
self.fps_spinbox.setMaximum(1000)
self.fps_spinbox.setValue(30)
self.layout().addRow("fps:", self.fps_spinbox)

def _create_file_path_widget(self):
"""Create a line edit and browse button for selecting the file path.

This allows the user to either browse the file system,
or type the path directly into the line edit.
"""
# File path line edit and browse button
self.file_path_edit = QLineEdit()
self.browse_button = QPushButton("Browse")
self.browse_button.clicked.connect(self._on_browse_clicked)
# Layout for line edit and button
self.file_path_layout = QHBoxLayout()
self.file_path_layout.addWidget(self.file_path_edit)
self.file_path_layout.addWidget(self.browse_button)
self.layout().addRow("file path:", self.file_path_layout)

def _create_load_button(self):
"""Create a button to load the file and add layers to the viewer."""
self.load_button = QPushButton("Load")
self.load_button.clicked.connect(lambda: self._on_load_clicked())
self.layout().addRow(self.load_button)

def _on_browse_clicked(self):
"""Open a file dialog to select a file."""
file_suffixes = SUPPORTED_POSES_FILES[

Check warning on line 88 in movement/napari/_loader_widgets.py

View check run for this annotation

Codecov / codecov/patch

movement/napari/_loader_widgets.py#L88

Added line #L88 was not covered by tests
self.source_software_combo.currentText()
]
dlg = QFileDialog()
dlg.setFileMode(QFileDialog.ExistingFile)
dlg.setNameFilter(

Check warning on line 93 in movement/napari/_loader_widgets.py

View check run for this annotation

Codecov / codecov/patch

movement/napari/_loader_widgets.py#L91-L93

Added lines #L91 - L93 were not covered by tests
f"Files containing predicted poses ({' '.join(file_suffixes)})"
)
if dlg.exec_():
file_paths = dlg.selectedFiles()

Check warning on line 97 in movement/napari/_loader_widgets.py

View check run for this annotation

Codecov / codecov/patch

movement/napari/_loader_widgets.py#L96-L97

Added lines #L96 - L97 were not covered by tests
# Set the file path in the line edit
self.file_path_edit.setText(file_paths[0])

Check warning on line 99 in movement/napari/_loader_widgets.py

View check run for this annotation

Codecov / codecov/patch

movement/napari/_loader_widgets.py#L99

Added line #L99 was not covered by tests

def _on_load_clicked(self):
"""Load the file and add as a Points layer to the viewer."""
fps = self.fps_spinbox.value()
source_software = self.source_software_combo.currentText()
file_path = self.file_path_edit.text()
if file_path == "":
show_warning("No file path specified.")
return
ds = load_poses.from_file(file_path, source_software, fps)

self.data, self.props = poses_to_napari_tracks(ds)
logger.info("Converted poses dataset to a napari Tracks array.")
logger.debug(f"Tracks array shape: {self.data.shape}")

self.file_name = Path(file_path).name
self._add_points_layer()

self._set_playback_fps(fps)
logger.debug(f"Set napari playback speed to {fps} fps.")

def _add_points_layer(self):
"""Add the predicted poses to the viewer as a Points layer."""
# Style properties for the napari Points layer
points_style = PointsStyle(
name=f"poses: {self.file_name}",
properties=self.props,
)
# Color the points by individual if there are multiple individuals
# Otherwise, color by keypoint
n_individuals = len(self.props["individual"].unique())
points_style.set_color_by(
prop="individual" if n_individuals > 1 else "keypoint"
)
# Add the points layer to the viewer
self.viewer.add_points(self.data[:, 1:], **points_style.as_kwargs())
logger.info("Added poses dataset as a napari Points layer.")

@staticmethod
def _set_playback_fps(fps: int):
"""Set the playback speed for the napari viewer."""
settings = get_settings()
settings.application.playback_fps = fps

@staticmethod
def _enable_layer_tooltips():
"""Toggle on tooltip visibility for napari layers.

This nicely displays the layer properties as a tooltip
when hovering over the layer in the napari viewer.
"""
settings = get_settings()
settings.appearance.layer_tooltip_visibility = True
6 changes: 3 additions & 3 deletions movement/napari/_meta_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from brainglobe_utils.qtpy.collapsible_widget import CollapsibleWidgetContainer
from napari.viewer import Viewer

from movement.napari._loader_widget import Loader
from movement.napari._loader_widgets import PosesLoader


class MovementMetaWidget(CollapsibleWidgetContainer):
Expand All @@ -18,9 +18,9 @@ def __init__(self, napari_viewer: Viewer, parent=None):
super().__init__()

self.add_widget(
Loader(napari_viewer, parent=self),
PosesLoader(napari_viewer, parent=self),
collapsible=True,
widget_title="Load data",
widget_title="Load poses",
)

self.loader = self.collapsible_widgets[0]
Expand Down
81 changes: 81 additions & 0 deletions movement/napari/convert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
"""Conversion functions from ``movement`` datasets to napari layers."""

import logging

import numpy as np
import pandas as pd
import xarray as xr

# get logger
logger = logging.getLogger(__name__)


def _construct_properties_dataframe(ds: xr.Dataset) -> pd.DataFrame:
"""Construct a properties DataFrame from a ``movement`` dataset."""
return pd.DataFrame(
{
"individual": ds.coords["individuals"].values,
"keypoint": ds.coords["keypoints"].values,
"time": ds.coords["time"].values,
"confidence": ds["confidence"].values.flatten(),
}
)


def poses_to_napari_tracks(ds: xr.Dataset) -> tuple[np.ndarray, pd.DataFrame]:
"""Convert poses dataset to napari Tracks array and properties.

Parameters
----------
ds : xr.Dataset
``movement`` dataset containing pose tracks, confidence scores,
and associated metadata.

Returns
-------
data : np.ndarray
napari Tracks array with shape (N, 4),
where N is n_keypoints * n_individuals * n_frames
and the 4 columns are (track_id, frame_idx, y, x).
properties : pd.DataFrame
DataFrame with properties (individual, keypoint, time, confidence).

Notes
-----
A corresponding napari Points array can be derived from the Tracks array
by taking its last 3 columns: (frame_idx, y, x). See the documentation
on the napari Tracks [1]_ and Points [2]_ layers.

References
----------
.. [1] https://napari.org/stable/howtos/layers/tracks.html
.. [2] https://napari.org/stable/howtos/layers/points.html

"""
ds_ = ds.copy() # make a copy to avoid modifying the original dataset

n_frames = ds_.sizes["time"]
n_individuals = ds_.sizes["individuals"]
n_keypoints = ds_.sizes["keypoints"]
n_tracks = n_individuals * n_keypoints

# assign unique integer ids to individuals and keypoints
ds_.coords["individual_ids"] = ("individuals", range(n_individuals))
ds_.coords["keypoint_ids"] = ("keypoints", range(n_keypoints))

# Stack 3 dimensions into a new single dimension named "tracks"
ds_ = ds_.stack(tracks=("individuals", "keypoints", "time"))
# Track ids are unique ints (individual_id * n_keypoints + keypoint_id)
individual_ids = ds_.coords["individual_ids"].values
keypoint_ids = ds_.coords["keypoint_ids"].values
track_ids = (individual_ids * n_keypoints + keypoint_ids).reshape(-1, 1)

# Construct the napari Tracks array
yx_columns = np.fliplr(ds_["position"].values.T)
time_column = np.tile(range(n_frames), n_tracks).reshape(-1, 1)
data = np.hstack((track_ids, time_column, yx_columns))

# Construct the properties DataFrame
properties = _construct_properties_dataframe(ds_)

return data, properties
68 changes: 68 additions & 0 deletions movement/napari/layer_styles.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
"""Dataclasses containing layer styles for napari."""

from dataclasses import dataclass, field

import numpy as np
import pandas as pd
from napari.utils.colormaps import ensure_colormap

DEFAULT_COLORMAP = "turbo"


@dataclass
class LayerStyle:
"""Base class for napari layer styles."""

name: str
properties: pd.DataFrame
visible: bool = True
blending: str = "translucent"

def as_kwargs(self) -> dict:
"""Return the style properties as a dictionary of kwargs."""
return self.__dict__


@dataclass
class PointsStyle(LayerStyle):
"""Style properties for a napari Points layer."""

name: str
properties: pd.DataFrame
visible: bool = True
blending: str = "translucent"
symbol: str = "disc"
size: int = 10
border_width: int = 0
face_color: str | None = None
face_color_cycle: list[tuple] | None = None
face_colormap: str = DEFAULT_COLORMAP
text: dict = field(default_factory=lambda: {"visible": False})

def set_color_by(self, prop: str, cmap: str | None = None) -> None:
"""Set the face_color to a column in the properties DataFrame.

Parameters
----------
prop : str
The column name in the properties DataFrame to color by.
cmap : str, optional
The name of the colormap to use, otherwise use the face_colormap.

"""
if cmap is None:
cmap = self.face_colormap
self.face_color = prop
self.text["string"] = prop
n_colors = len(self.properties[prop].unique())
self.face_color_cycle = _sample_colormap(n_colors, cmap)


def _sample_colormap(n: int, cmap_name: str) -> list[tuple]:
"""Sample n equally-spaced colors from a napari colormap.

This includes the endpoints of the colormap.
"""
cmap = ensure_colormap(cmap_name)
samples = np.linspace(0, len(cmap.colors) - 1, n).astype(int)
return [tuple(cmap.colors[i]) for i in samples]
8 changes: 2 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,8 @@ entry-points."napari.manifest".movement = "movement.napari:napari.yaml"

[project.optional-dependencies]
napari = [
"napari[all]>=0.4.19",
# the rest will be replaced by brainglobe-utils[qt]>=0.6 after release
"brainglobe-atlasapi>=2.0.7",
"brainglobe-utils>=0.5",
"qtpy",
"superqt",
"napari[all]>=0.5.0",
"brainglobe-utils[qt]>=0.6" # needed for collapsible widgets
]
dev = [
"pytest",
Expand Down
Loading
Loading