-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* added sample DLC pose estimation output (tracks) in h5 and pickle * added h5py as dependency * implemented import tracks from DLC .h5 file * implemented unit tests for loading DLC .h5 * fixed linting and styling * fix manifest issue * merged nested if statements * merged implicitly concatenated strings * catch more specific exceptions * DLC picle file holds metadata * added tables dependency * use tox-conda to install pytables in the testing environment * replaced pickle with .csv file * fixe typo in tox.ini * try removing all hdf5 related dependencies from pyproject * removed fancylog dependency * added h5py dependency * renamed tracks to poses * replaced utils with validator using pydantic * replace all remaining mentions of tracks with poses * Got rid of key argument when importing pose data from DLC h5 * Got rid of the dataframe validator * grouped tests into test class * moved validators into io module * renamed filepath to file_path * added support for loading DLC poses in .csv format * convert to Path before we run any other validators * remove remaining reference to HDF5 file * added circular logs * moved pydantic to non-dev dependencies * added conda installation instructions to README * exclude data folder from manifest * changed logging format and remove console logging * DeepLabCut written as camel case * more informative error message * removed TODO * send log file to .movement folder * configure logging during import * added optional logging to the console * importing dlc poses from h5 or csv return the same dataframe * remove unused tmp_path argument from test * conda create env with py3.11 in README.md Co-authored-by: Adam Tyson <adam.tyson@ucl.ac.uk> * reformatted long string * Save one 5MB log file instead of having 5 backups of 1MB each * changed default logging level to DEBUG * added rich as dependency * added log_directory argument to configure_logging * redirect loggin to different path during tests * added test for logging * removed console logging with rich * use pandas.testing module to asssert df equality --------- Co-authored-by: Adam Tyson <adam.tyson@ucl.ac.uk>
- Loading branch information
1 parent
29e7b81
commit 11ae2d1
Showing
15 changed files
with
1,440 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
include LICENSE | ||
include README.md | ||
|
||
recursive-exclude data * | ||
recursive-exclude * __pycache__ | ||
recursive-exclude * *.py[co] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,12 @@ | ||
from importlib.metadata import PackageNotFoundError, version | ||
from movement.log_config import configure_logging | ||
|
||
try: | ||
__version__ = version("movement") | ||
except PackageNotFoundError: | ||
# package is not installed | ||
pass | ||
|
||
|
||
# initialize logger upon import | ||
configure_logging() |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
import logging | ||
from pathlib import Path | ||
from typing import Optional, Union | ||
|
||
import pandas as pd | ||
|
||
from movement.io.validators import DeepLabCutPosesFile | ||
|
||
# get logger | ||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def from_dlc(file_path: Union[Path, str]) -> Optional[pd.DataFrame]: | ||
"""Load pose estimation results from a DeepLabCut (DLC) files. | ||
Files must be in .h5 format or .csv format. | ||
Parameters | ||
---------- | ||
file_path : pathlib Path or str | ||
Path to the file containing the DLC poses. | ||
Returns | ||
------- | ||
pandas DataFrame | ||
DataFrame containing the DLC poses | ||
Examples | ||
-------- | ||
>>> from movement.io import load_poses | ||
>>> poses = load_poses.from_dlc("path/to/file.h5") | ||
""" | ||
|
||
# Validate the input file path | ||
dlc_poses_file = DeepLabCutPosesFile(file_path=file_path) # type: ignore | ||
file_suffix = dlc_poses_file.file_path.suffix | ||
|
||
# Load the DLC poses | ||
try: | ||
if file_suffix == ".csv": | ||
df = _parse_dlc_csv_to_dataframe(dlc_poses_file.file_path) | ||
else: # file can only be .h5 at this point | ||
df = pd.read_hdf(dlc_poses_file.file_path) | ||
# above line does not necessarily return a DataFrame | ||
df = pd.DataFrame(df) | ||
except (OSError, TypeError, ValueError) as e: | ||
error_msg = ( | ||
f"Could not load poses from {file_path}. " | ||
"Please check that the file is valid and readable." | ||
) | ||
logger.error(error_msg) | ||
raise OSError from e | ||
logger.info(f"Loaded poses from {file_path}") | ||
return df | ||
|
||
|
||
def _parse_dlc_csv_to_dataframe(file_path: Path) -> pd.DataFrame: | ||
"""If poses are loaded from a DeepLabCut.csv file, the resulting DataFrame | ||
lacks the multi-index columns that are present in the .h5 file. This | ||
function parses the csv file to a DataFrame with multi-index columns. | ||
Parameters | ||
---------- | ||
file_path : pathlib Path | ||
Path to the file containing the DLC poses, in .csv format. | ||
Returns | ||
------- | ||
pandas DataFrame | ||
DataFrame containing the DLC poses, with multi-index columns. | ||
""" | ||
|
||
possible_level_names = ["scorer", "bodyparts", "coords", "individual"] | ||
with open(file_path, "r") as f: | ||
# if line starts with a possible level name, split it into a list | ||
# of strings, and add it to the list of header lines | ||
header_lines = [ | ||
line.strip().split(",") | ||
for line in f.readlines() | ||
if line.split(",")[0] in possible_level_names | ||
] | ||
|
||
# Form multi-index column names from the header lines | ||
level_names = [line[0] for line in header_lines] | ||
column_tuples = list(zip(*[line[1:] for line in header_lines])) | ||
columns = pd.MultiIndex.from_tuples(column_tuples, names=level_names) | ||
|
||
# Import the DLC poses as a DataFrame | ||
df = pd.read_csv( | ||
file_path, skiprows=len(header_lines), index_col=0, names=columns | ||
) | ||
df.columns.rename(level_names, inplace=True) | ||
return df |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
import logging | ||
from pathlib import Path | ||
|
||
from pydantic import BaseModel, validator | ||
|
||
# initialize logger | ||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class DeepLabCutPosesFile(BaseModel): | ||
"""Pydantic class for validating files containing | ||
pose estimation results from DeepLabCut (DLC). | ||
Pydantic will enforce the input data type. | ||
This class additionally checks that the file exists | ||
and has a valid suffix. | ||
""" | ||
|
||
file_path: Path | ||
|
||
@validator("file_path", pre=True) # runs before other validators | ||
def convert_to_path(cls, value): | ||
return Path(value) | ||
|
||
@validator("file_path") | ||
def file_must_exist(cls, value): | ||
if not value.is_file(): | ||
error_msg = f"File not found: {value}" | ||
logger.error(error_msg) | ||
raise FileNotFoundError(error_msg) | ||
return value | ||
|
||
@validator("file_path") | ||
def file_must_have_valid_suffix(cls, value): | ||
if value.suffix not in (".h5", ".csv"): | ||
error_msg = ( | ||
"Expected a file with pose estimation results from " | ||
"DeepLabCut, in one of '.h5' or '.csv' formats. " | ||
f"Received a file with suffix '{value.suffix}' instead." | ||
) | ||
logger.error(error_msg) | ||
raise ValueError(error_msg) | ||
return value |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
import logging | ||
from logging.handlers import RotatingFileHandler | ||
from pathlib import Path | ||
|
||
FORMAT = ( | ||
"%(asctime)s - %(levelname)s - " | ||
"%(processName)s %(filename)s:%(lineno)s - %(message)s" | ||
) | ||
|
||
|
||
def configure_logging( | ||
log_level: int = logging.DEBUG, | ||
logger_name: str = "movement", | ||
log_directory: Path = Path.home() / ".movement", | ||
): | ||
"""Configure the logging module. | ||
This function sets up a circular log file with a rotating file handler. | ||
Parameters | ||
---------- | ||
log_level : int, optional | ||
The logging level to use. Defaults to logging.INFO. | ||
logger_name : str, optional | ||
The name of the logger to configure. | ||
Defaults to 'movement'. | ||
log_directory : pathlib.Path, optional | ||
The directory to store the log file in. Defaults to | ||
~/.movement. A different directory can be specified, | ||
for example for testing purposes. | ||
""" | ||
|
||
# Set the log directory and file path | ||
log_directory.mkdir(parents=True, exist_ok=True) | ||
log_file = log_directory / f"{logger_name}.log" | ||
|
||
# If a logger with the given name is already configured | ||
if logger_name in logging.root.manager.loggerDict: | ||
logger = logging.getLogger(logger_name) | ||
handlers = logger.handlers[:] | ||
# If the log file path has changed | ||
if log_file.as_posix() != handlers[0].baseFilename: # type: ignore | ||
# remove the handlers to allow for reconfiguration | ||
for handler in handlers: | ||
logger.removeHandler(handler) | ||
else: | ||
# otherwise, do nothing | ||
return | ||
|
||
logger = logging.getLogger(logger_name) | ||
logger.setLevel(log_level) | ||
|
||
# Create a rotating file handler | ||
max_log_size = 5 * 1024 * 1024 # 5 MB | ||
handler = RotatingFileHandler(log_file, maxBytes=max_log_size) | ||
|
||
# Create a formatter and set it to the handler | ||
formatter = logging.Formatter(FORMAT) | ||
handler.setFormatter(formatter) | ||
|
||
# Add the handler to the logger | ||
logger.addHandler(handler) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
import logging | ||
|
||
import pytest | ||
|
||
from movement.log_config import configure_logging | ||
|
||
|
||
@pytest.fixture(autouse=True) | ||
def setup_logging(tmp_path): | ||
"""Set up logging for the test module. | ||
Redirects all logging to a temporary directory.""" | ||
configure_logging( | ||
log_level=logging.DEBUG, | ||
logger_name="movement", | ||
log_directory=(tmp_path / ".movement"), | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
import os | ||
from pathlib import Path | ||
|
||
import h5py | ||
import pandas as pd | ||
import pytest | ||
from pandas.testing import assert_frame_equal | ||
from pydantic import ValidationError | ||
|
||
from movement.io import load_poses | ||
|
||
|
||
class TestLoadPoses: | ||
"""Test the load_poses module.""" | ||
|
||
@pytest.fixture | ||
def valid_dlc_files(self): | ||
"""Return the paths to valid DLC poses files, | ||
in .h5 format. | ||
Returns | ||
------- | ||
dict | ||
Dictionary containing the paths. | ||
- h5_path: pathlib Path to a valid .h5 file | ||
- h5_str: path as str to a valid .h5 file | ||
""" | ||
test_data_dir = Path(__file__).parent.parent.parent / "data" | ||
h5_file = test_data_dir / "DLC_sample_poses.h5" | ||
csv_file = test_data_dir / "DLC_sample_poses.csv" | ||
return { | ||
"h5_path": h5_file, | ||
"h5_str": h5_file.as_posix(), | ||
"csv_path": csv_file, | ||
"csv_str": csv_file.as_posix(), | ||
} | ||
|
||
@pytest.fixture | ||
def invalid_files(self, tmp_path): | ||
unreadable_file = tmp_path / "unreadable.h5" | ||
with open(unreadable_file, "w") as f: | ||
f.write("unreadable data") | ||
os.chmod(f.name, 0o000) | ||
|
||
wrong_ext_file = tmp_path / "wrong_extension.txt" | ||
with open(wrong_ext_file, "w") as f: | ||
f.write("") | ||
|
||
h5_file_no_dataframe = tmp_path / "no_dataframe.h5" | ||
with h5py.File(h5_file_no_dataframe, "w") as f: | ||
f.create_dataset("data_in_list", data=[1, 2, 3]) | ||
|
||
nonexistent_file = tmp_path / "nonexistent.h5" | ||
|
||
return { | ||
"unreadable": unreadable_file, | ||
"wrong_ext": wrong_ext_file, | ||
"no_dataframe": h5_file_no_dataframe, | ||
"nonexistent": nonexistent_file, | ||
} | ||
|
||
def test_load_valid_dlc_files(self, valid_dlc_files): | ||
"""Test loading valid DLC poses files.""" | ||
for file_type, file_path in valid_dlc_files.items(): | ||
df = load_poses.from_dlc(file_path) | ||
assert isinstance(df, pd.DataFrame) | ||
assert not df.empty | ||
|
||
def test_load_invalid_dlc_files(self, invalid_files): | ||
"""Test loading invalid DLC poses files.""" | ||
for file_type, file_path in invalid_files.items(): | ||
if file_type == "nonexistent": | ||
with pytest.raises(FileNotFoundError): | ||
load_poses.from_dlc(file_path) | ||
elif file_type == "wrong_ext": | ||
with pytest.raises(ValueError): | ||
load_poses.from_dlc(file_path) | ||
else: | ||
with pytest.raises(OSError): | ||
load_poses.from_dlc(file_path) | ||
|
||
@pytest.mark.parametrize("file_path", [1, 1.0, True, None, [], {}]) | ||
def test_load_from_dlc_with_incorrect_file_path_types(self, file_path): | ||
"""Test loading poses from a file_path with an incorrect type.""" | ||
with pytest.raises(ValidationError): | ||
load_poses.from_dlc(file_path) | ||
|
||
def test_load_from_dlc_csv_or_h5_file_returns_same_df( | ||
self, valid_dlc_files | ||
): | ||
"""Test that loading poses from DLC .csv and .h5 files | ||
return the same DataFrame.""" | ||
df_from_h5 = load_poses.from_dlc(valid_dlc_files["h5_path"]) | ||
df_from_csv = load_poses.from_dlc(valid_dlc_files["csv_path"]) | ||
assert_frame_equal(df_from_h5, df_from_csv) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
import logging | ||
|
||
import pytest | ||
|
||
log_messages = { | ||
"DEBUG": "This is a debug message", | ||
"INFO": "This is an info message", | ||
"WARNING": "This is a warning message", | ||
"ERROR": "This is an error message", | ||
} | ||
|
||
|
||
@pytest.mark.parametrize("level, message", log_messages.items()) | ||
def test_logfile_contains_message(level, message): | ||
"""Check if the last line of the logfile contains | ||
the expected message.""" | ||
logger = logging.getLogger("movement") | ||
eval(f"logger.{level.lower()}('{message}')") | ||
log_file = logger.handlers[0].baseFilename | ||
with open(log_file, "r") as f: | ||
last_line = f.readlines()[-1] | ||
assert level in last_line | ||
assert message in last_line |
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.