Adding support for more .nwb training data to SLEAP #104
Conversation
Walkthrough

The recent updates enhance interoperability between SLEAP and NWB formats by introducing new functions for handling pose training and skeleton data. These changes improve data management, incorporate better error handling, and refine the loading and saving processes, ensuring seamless integration for users working with labeled datasets.
Actionable comments posted: 2
Outside diff range, codebase verification and nitpick comments (1)

sleap_io/io/main.py (1)

Lines 50-66: Ensure `load_nwb` function calls match the new signature.

The `load_nwb` function calls in `tests/io/test_main.py` do not match the new signature: the `as_training` parameter is missing. `load_nwb` is called twice in `tests/io/test_main.py` as:

```python
loaded_labels = load_nwb(tmp_path / "test_nwb.nwb")
```

Please update these calls to include the `as_training` parameter.

Analysis chain

LGTM! But verify the function usage in the codebase.

The code changes are approved. However, ensure that all function calls to `load_nwb` match the new signature.

Scripts executed

The following scripts were executed for the analysis:

Script:

```shell
#!/bin/bash
# Description: Verify all function calls to `load_nwb` match the new signature.
# Test: Search for the function usage. Expect: Only occurrences of the new signature.
rg --type python -A 5 $'load_nwb'
```

Length of output: 63

Script:

```shell
#!/bin/bash
# Description: Verify all function calls to `load_nwb` match the new signature.
# Test: Search for the function usage. Expect: Only occurrences of the new signature.
rg --type py -A 5 $'load_nwb'
```

Length of output: 1954

Script:

```shell
#!/bin/bash
# Description: Verify that the function calls to `load_nwb` in tests/io/test_main.py match the new signature.
# Test: Extract lines of code where `load_nwb` is called in tests/io/test_main.py.
rg 'load_nwb' tests/io/test_main.py -A 5 -B 5
```

Length of output: 908
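Text search with `rg` can match comments or miss call forms; as a sturdier alternative, call sites can be checked syntactically with Python's standard `ast` module. A minimal sketch (the sample source string below is illustrative, not taken from the repository):

```python
import ast

def calls_missing_kwarg(source, func_name, kwarg):
    """Return line numbers of calls to `func_name` that omit the keyword `kwarg`."""
    tree = ast.parse(source)
    missing = []
    for node in ast.walk(tree):
        if isinstance(node, ast.Call):
            # Handle both `load_nwb(...)` and `module.load_nwb(...)` call forms.
            name = getattr(node.func, "id", getattr(node.func, "attr", None))
            if name == func_name and kwarg not in {kw.arg for kw in node.keywords}:
                missing.append(node.lineno)
    return missing

code = 'loaded = load_nwb(path)\nloaded = load_nwb(path, as_training=True)\n'
print(calls_missing_kwarg(code, "load_nwb", "as_training"))  # [1]
```

Running this over each test file would flag exactly the call sites the review points out, independent of formatting.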
Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Files selected for processing (3)
- pyproject.toml (1 hunks)
- sleap_io/io/main.py (3 hunks)
- sleap_io/io/nwb.py (9 hunks)
Files skipped from review due to trivial changes (1)
- pyproject.toml
Additional comments not posted (7)
sleap_io/io/main.py (2)

Lines 69-93: Verify the logic for handling the `as_training` argument.

The logic for handling the `as_training` argument is correct and well-implemented. The function correctly differentiates between saving as a training dataset and saving as a regular dataset based on the value of `as_training`.
Lines 69-93: LGTM! But verify the function usage in the codebase.

The code changes are approved. However, ensure that all function calls to `save_nwb` match the new signature.

Verification successful

Function calls to `save_nwb` match the new signature in the following files:

- sleap_io/io/main.py
- tests/io/test_main.py

No further issues were found.

Scripts executed

The following scripts were executed for the analysis:

Script:

```shell
#!/bin/bash
# Description: Verify all function calls to `save_nwb` match the new signature.
# Test: Search for the function usage. Expect: Only occurrences of the new signature.
rg --type python -A 5 $'save_nwb'
```

Length of output: 63

Script:

```shell
#!/bin/bash
# Description: Verify all function calls to `save_nwb` match the new signature.
# Test: Search for the function usage. Expect: Only occurrences of the new signature.
rg --glob '*.py' -A 5 $'save_nwb'
```

Length of output: 2089

Script:

```shell
#!/bin/bash
# Description: Extract lines of code where `save_nwb` is used to verify the function signature.

# Extract lines from sleap_io/__init__.py
rg --glob 'sleap_io/__init__.py' -A 5 $'save_nwb'

# Extract lines from sleap_io/io/main.py
rg --glob 'sleap_io/io/main.py' -A 5 $'save_nwb'

# Extract lines from tests/io/test_main.py
rg --glob 'tests/io/test_main.py' -A 5 $'save_nwb'
```

Length of output: 2191
sleap_io/io/nwb.py (5)

Lines 92-144: LGTM! The function correctly converts a SLEAP `Labels` object to an NWB `PoseTraining` object.

Lines 147-169: LGTM! The function correctly converts a SLEAP `Skeleton` to an NWB `Skeleton`.

Lines 172-191: LGTM! The function correctly converts a SLEAP `Instance` to an NWB `SkeletonInstance`.

Lines 194-216: LGTM! The function correctly converts a list of SLEAP `Video`s to NWB `SourceVideos`.

Lines 394-411: LGTM! The function correctly handles the writing of `PoseTraining` data to an NWB file, including metadata handling.
sleap_io/io/main.py
Outdated
```diff
 def load_nwb(filename: str, as_training: Optional[bool]=None) -> Labels:
     """Load an NWB dataset as a SLEAP `Labels` object.

     Args:
         filename: Path to a NWB file (`.nwb`).
         as_training: If `True`, load the dataset as a training dataset.

     Returns:
         The dataset as a `Labels` object.
     """
-    return nwb.read_nwb(filename)
+    if as_training is None:
+        return
+
+    if as_training:
+        return nwb.read_nwb_training(filename)
+    else:
+        return nwb.read_nwb(filename)
```
Verify the logic for handling the `as_training` argument.

The logic for handling the `as_training` argument is correct, but the initial check for `as_training is None` is redundant. The function should directly check the value of `as_training` without the initial `if` condition.
```diff
-    if as_training is None:
-        return
-
     if as_training:
         return nwb.read_nwb_training(filename)
     else:
         return nwb.read_nwb(filename)
```
Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
def load_nwb(filename: str, as_training: Optional[bool]=None) -> Labels:
    """Load an NWB dataset as a SLEAP `Labels` object.

    Args:
        filename: Path to a NWB file (`.nwb`).
        as_training: If `True`, load the dataset as a training dataset.

    Returns:
        The dataset as a `Labels` object.
    """
    if as_training:
        return nwb.read_nwb_training(filename)
    else:
        return nwb.read_nwb(filename)
```
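The suggestion above relies on `None` being falsy, so a separate `is None` check adds nothing. A self-contained sketch with stand-in readers (the real ones live in `sleap_io.io.nwb`; the return strings here are illustrative only):

```python
from typing import Optional

# Stand-in readers; stand-ins for nwb.read_nwb and nwb.read_nwb_training.
def read_nwb(filename: str) -> str:
    return f"labels from {filename}"

def read_nwb_training(filename: str) -> str:
    return f"training labels from {filename}"

def load_nwb(filename: str, as_training: Optional[bool] = None) -> str:
    # `None` and `False` both take the non-training path, so no extra check is needed.
    return read_nwb_training(filename) if as_training else read_nwb(filename)

print(load_nwb("f.nwb"))                    # labels from f.nwb
print(load_nwb("f.nwb", as_training=True))  # training labels from f.nwb
```

The conditional expression keeps the default (`None`) behavior identical to passing `False`, which is what the reviewer's simplification depends on.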
sleap_io/io/nwb.py
Outdated
```python
def pose_training_to_labels(pose_training: PoseTraining) -> Labels:  # type: ignore[return]
    """Creates a Labels object from an NWB PoseTraining object.

    Args:
        pose_training: An NWB PoseTraining object.

    Returns:
        A Labels object.
    """
    labeled_frames = []
    for training_frame in pose_training.training_frames:
        video = Video(filename=f"{training_frame.source_videos}")
        frame_idx = training_frame  # TODO
        instances = [
            Instance.from_numpy(
                points=instance.node_locations,
                skeleton=nwb_skeleton_to_sleap(instance.skeleton),
            )
            for instance in training_frame.skeleton_instances
        ]
        labeled_frames.append(
            LabeledFrame(video=video, frame_idx=frame_idx, instances=instances)
        )
    return Labels(labeled_frames=labeled_frames)
```
Incomplete functionality in `pose_training_to_labels`.

The function contains a TODO comment indicating incomplete functionality for determining `frame_idx` from `training_frame`.

```diff
-        frame_idx = training_frame  # TODO
+        frame_idx = training_frame.frame_index  # Assuming `frame_index` is an attribute of `TrainingFrame`
```

Do you want me to assist in completing this functionality or open a GitHub issue to track this task?
Committable suggestion was skipped due to low confidence.
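Later review discussion in this thread confirms that the NWB structure exposes a `source_video_frame_index` attribute, which is the natural value for `frame_idx`. A sketch of the fix with a plain-dataclass stand-in (this is not the actual `ndx_pose.TrainingFrame` class, just an illustration of the attribute access):

```python
from dataclasses import dataclass

@dataclass
class TrainingFrame:
    """Minimal stand-in for the NWB TrainingFrame (illustrative only)."""
    source_video_frame_index: int

def frame_index_of(training_frame: TrainingFrame) -> int:
    # Resolves the TODO: take the index from the frame's attribute,
    # rather than assigning the frame object itself to frame_idx.
    return int(training_frame.source_video_frame_index)

print(frame_index_of(TrainingFrame(source_video_frame_index=42)))  # 42
```

In `pose_training_to_labels`, the loop body would then read `frame_idx = training_frame.source_video_frame_index`, assuming that attribute name holds in the installed `ndx-pose` version.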
Actionable comments posted: 2
Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Files selected for processing (1)
- sleap_io/io/nwb.py (9 hunks)
Additional comments not posted (8)
sleap_io/io/nwb.py (8)
Lines 18-19: Update imports to include `ImageSeries`.

The addition of `ImageSeries` from `pynwb.image` is appropriate given the new functionality related to handling image data. This aligns with the PR's objective of enhancing NWB data handling.

Lines 39-43: Review the renaming of `Skeleton` to `SLEAPSkeleton`.

Renaming `Skeleton` to `SLEAPSkeleton` helps avoid naming conflicts with `Skeleton` from `ndx_pose`. This is a good practice to prevent ambiguity in large codebases.

Lines 172-191: Review the function `instance_to_skeleton_instance`.

This function converts a SLEAP `Instance` to an NWB `SkeletonInstance`. It correctly handles the conversion of points to node locations and ensures the visibility of nodes is maintained. The use of `name_generator` for the instance name and the setting of an arbitrary ID (`np.uint(10)`) should be reviewed for appropriateness and potential conflicts.

Consider verifying the use of `name_generator` and the rationale behind the fixed ID setting.
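A hard-coded ID such as `np.uint(10)` collides as soon as two instances are converted; a per-conversion counter yields unique, deterministic IDs instead. A minimal sketch (numpy omitted to keep it self-contained; the real code would wrap the result in `np.uint`):

```python
from itertools import count

def make_id_allocator(start: int = 0):
    """Return a callable that hands out sequential instance IDs."""
    counter = count(start)
    return lambda: next(counter)

# One allocator per conversion run keeps IDs unique within that run.
next_id = make_id_allocator()
ids = [next_id() for _ in range(3)]
print(ids)  # [0, 1, 2]
```

Each call to `make_id_allocator` starts a fresh sequence, so repeated conversions produce the same IDs every time, which is friendlier to round-trip tests than a fixed constant.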
Lines 48-71: Review the function `pose_training_to_labels`.

This function converts `PoseTraining` objects to `Labels`. The loop iterates over `training_frames`, creating `LabeledFrame` instances. Each instance is created using data from `training_frame`, including the video filename and frame index. The use of `Instance.from_numpy` suggests that the conversion handles the data appropriately.

However, ensure that `source_video_frame_index` and other attributes are correctly mapped and exist in the NWB structure as expected.

Verification successful

The attributes `source_video_frame_index`, `source_video`, and `skeleton_instances` are correctly mapped and exist in the NWB structure as expected. They are used appropriately within the NWB and SLEAP structures, ensuring the conversion in `pose_training_to_labels` is handled correctly.

Scripts executed

The following scripts were executed for the analysis:

Script:

```shell
#!/bin/bash
# Description: Verify the existence and correct mapping of attributes in NWB structure.
# Test: Search for the usage of `source_video_frame_index` and related attributes in the codebase. Expect: Proper usage and definitions.
rg --type python -A 5 $'source_video_frame_index'
```

Length of output: 79

Script:

```shell
#!/bin/bash
# Description: Verify the existence and correct mapping of attributes in NWB structure.
# Test: Search for the usage of `source_video_frame_index` and related attributes in the codebase. Expect: Proper usage and definitions.
rg --type py -A 5 'source_video_frame_index'
rg --type py -A 5 'source_video'
rg --type py -A 5 'skeleton_instances'
```

Length of output: 12535
Lines 74-89: Review the function `nwb_skeleton_to_sleap`.

The function converts an NWB skeleton to a SLEAP skeleton. The conversion of nodes and edges is straightforward and uses list comprehensions. The function appears to handle the data correctly according to the provided structure and types.

Ensure that the `nodes` and `edges` attributes are correctly utilized in the NWB and SLEAP data structures.

Verification successful

The `nodes` and `edges` attributes are correctly utilized in the NWB and SLEAP data structures. The search results confirm extensive usage and testing of these attributes throughout the codebase, indicating that `nwb_skeleton_to_sleap` handles them correctly.

Verified in:

- tests/io/test_nwb.py
- tests/model/test_skeleton.py
- sleap_io/io/nwb.py
- sleap_io/io/slp.py
- sleap_io/model/skeleton.py

Scripts executed

The following scripts were executed for the analysis:

Script:

```shell
#!/bin/bash
# Description: Verify the correct utilization of `nodes` and `edges` attributes in the data structures.
# Test: Search for the usage of `nodes` and `edges` in the codebase. Expect: Correct utilization in NWB and SLEAP structures.
rg --type python -A 5 $'nodes|edges'
```

Length of output: 68

Script:

```shell
#!/bin/bash
# Description: Verify the correct utilization of `nodes` and `edges` attributes in the data structures.
# Test: Search for the usage of `nodes` and `edges` in the codebase. Expect: Correct utilization in NWB and SLEAP structures.
rg -t py -A 5 'nodes|edges'
```

Length of output: 44681
Lines 92-144: Review the function `labels_to_pose_training`.

This function creates an NWB `PoseTraining` object from a `Labels` object. It iterates over labeled frames, extracting instances and converting them using `instance_to_skeleton_instance`. The creation of `TrainingFrame` and `SourceVideos` is handled well, with attention to detail in metadata and video handling.

However, the function uses a custom `name_generator` which might introduce non-determinism or conflicts. Verify the implementation of this generator to ensure it produces unique and appropriate names without conflicts.
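A deterministic `name_generator` can be as simple as a per-base-name counter: unique names, stable across runs, no randomness. A sketch of that idea (this is not the actual generator in `sleap_io`, just one way to satisfy the reviewer's concern):

```python
from collections import defaultdict

class NameGenerator:
    """Deterministic name generator: base name plus an incrementing suffix."""

    def __init__(self):
        self._counts = defaultdict(int)

    def __call__(self, base: str) -> str:
        n = self._counts[base]
        self._counts[base] += 1
        return f"{base}_{n}"

gen = NameGenerator()
print(gen("skeleton_instance"))  # skeleton_instance_0
print(gen("skeleton_instance"))  # skeleton_instance_1
print(gen("training_frame"))     # training_frame_0
```

Because names depend only on call order, converting the same `Labels` object twice yields identical NWB group names, which makes round-trip tests reproducible.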
Lines 194-216: Review the function `videos_to_source_videos`.

This function converts a list of SLEAP Videos to NWB SourceVideos. The handling of video metadata and the creation of `ImageSeries` objects are done correctly. The function ensures that each video is represented as an `ImageSeries`, preserving essential metadata like dimensions and rate.

Ensure that the video conversion process does not lose any critical information and that the `ImageSeries` objects are correctly utilized in the NWB structure.

Verification successful

The function converts a list of SLEAP Videos to NWB SourceVideos, ensuring that each video is represented as an `ImageSeries` with essential metadata like dimensions and rate. The presence of test cases in `tests/io/test_nwb.py` that verify the `dimensions` and `rate` attributes confirms that the video conversion process is correctly handled and no critical information is lost.

Relevant test cases:

- tests/io/test_nwb.py: contains tests verifying `dimensions`, `rate`, and other related attributes.

Scripts executed

The following scripts were executed for the analysis:

Script:

```shell
#!/bin/bash
# Description: Verify the correct utilization and information preservation in video conversion.
# Test: Search for the usage of `ImageSeries` and related attributes in the codebase. Expect: Correct utilization and information preservation.
rg --type python -A 5 $'ImageSeries|dimension|rate'
```

Length of output: 83

Script:

```shell
#!/bin/bash
# Description: Verify the correct utilization and information preservation in video conversion.
# Test: Search for the usage of `ImageSeries` and related attributes in the codebase. Expect: Correct utilization and information preservation.
rg -A 5 'ImageSeries|dimension|rate'
```

Length of output: 29675
Lines 147-169: Review the function `slp_skeleton_to_nwb`.

Converts a SLEAP skeleton to an NWB skeleton. The function handles the conversion of nodes and edges, ensuring that the data is correctly transformed. The use of a dictionary comprehension for `skeleton_edges` and the subsequent edge conversion is efficient.

Ensure that the conversion logic correctly preserves the relationships and data integrity between the original and converted structures.

Verification successful

The `slp_skeleton_to_nwb` function is correctly integrated into the `instance_to_skeleton_instance` function, ensuring that relationships and data integrity are preserved during the conversion process.

- sleap_io/io/nwb.py: verified usage and context of `slp_skeleton_to_nwb` within `instance_to_skeleton_instance`.

Scripts executed

The following scripts were executed for the analysis:

Script:

```shell
#!/bin/bash
# Description: Verify the data integrity and relationship preservation in skeleton conversion.
# Test: Search for the usage and results of skeleton conversion in the codebase. Expect: Correct data integrity and relationship preservation.
rg --type python -A 5 $'skeleton_edges|nwb_edges'
```

Length of output: 81

Script:

```shell
#!/bin/bash
# Description: Verify the data integrity and relationship preservation in skeleton conversion.
# Test: Search for the usage and results of skeleton conversion in the codebase. Expect: Correct data integrity and relationship preservation.
rg -A 5 'skeleton_edges|nwb_edges'
```

Length of output: 1120

Script:

```shell
#!/bin/bash
# Description: Verify the data integrity and relationship preservation in skeleton conversion.
# Test: Search for the usage of the slp_skeleton_to_nwb function in the codebase. Expect: Correct data integrity and relationship preservation.
rg 'slp_skeleton_to_nwb'
```

Length of output: 205

Script:

```shell
#!/bin/bash
# Description: Verify the data integrity and relationship preservation in skeleton conversion.
# Test: Search for the context around the call to slp_skeleton_to_nwb(instance.skeleton) in sleap_io/io/nwb.py. Expect: Correct data integrity and relationship preservation.
rg -A 10 -B 10 'slp_skeleton_to_nwb(instance.skeleton)' sleap_io/io/nwb.py
```

Length of output: 76

Script:

```shell
#!/bin/bash
# Description: Verify the data integrity and relationship preservation in skeleton conversion.
# Test: Search for the function call slp_skeleton_to_nwb in sleap_io/io/nwb.py with broader context. Expect: Correct data integrity and relationship preservation.
rg -A 10 -B 10 'slp_skeleton_to_nwb' sleap_io/io/nwb.py
```

Length of output: 1351
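The `skeleton_edges` dictionary comprehension the review refers to maps node names to indices so that edges can be stored as index pairs, which is how NWB skeletons represent connectivity. A standalone sketch of that transformation with plain strings for nodes (illustrative node names, not the real skeleton):

```python
def edges_to_indices(nodes, edges):
    """Convert (name, name) edges to (index, index) pairs via a name -> index map."""
    index_of = {name: i for i, name in enumerate(nodes)}
    return [(index_of[src], index_of[dst]) for src, dst in edges]

nodes = ["head", "thorax", "abdomen"]
edges = [("head", "thorax"), ("thorax", "abdomen")]
print(edges_to_indices(nodes, edges))  # [(0, 1), (1, 2)]
```

Round-tripping through the inverse map (`nodes[i]` for each index) recovers the original name pairs, which is the data-integrity property the verification above is checking for.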
sleap_io/io/nwb.py
Outdated
```python
def write_nwb_training(pose_training: PoseTraining,  # type: ignore[return]
                       nwbfile_path: str,
                       nwb_file_kwargs: Optional[dict],
                       pose_estimation_metadata: Optional[dict] = None,
                       ):
    """Writes data from a `PoseTraining` object to an NWB file.

    Args:
        pose_training: A `PoseTraining` object.
        nwbfile_path: The path where the nwb file is to be written.
    """
    nwb_file_kwargs = nwb_file_kwargs or {}

    nwbfile = NWBFile(**nwb_file_kwargs)
    nwbfile = append_nwb_data(pose_training, nwbfile, pose_estimation_metadata)
    with NWBHDF5IO(str(nwbfile_path), "w") as io:
        io.write(nwbfile)
```
Review the function `write_nwb_training`.

This function writes data from a `PoseTraining` object to an NWB file. The function setup and the handling of optional metadata are done correctly. The use of `append_nwb_data` to integrate the training data into the NWB file is appropriate.

However, the function lacks error handling for potential issues during the file writing process. Consider adding error handling mechanisms to improve robustness.
```diff
+    try:
         nwbfile = NWBFile(**nwb_file_kwargs)
         nwbfile = append_nwb_data(pose_training, nwbfile, pose_estimation_metadata)
         with NWBHDF5IO(str(nwbfile_path), "w") as io:
             io.write(nwbfile)
+    except Exception as e:
+        # Handle potential file writing errors
+        print(f"Error writing NWB file: {e}")
```
```python
def write_nwb_training(pose_training: PoseTraining,  # type: ignore[return]
                       nwbfile_path: str,
                       nwb_file_kwargs: Optional[dict],
                       pose_estimation_metadata: Optional[dict] = None,
                       ):
    """Writes data from a `PoseTraining` object to an NWB file.

    Args:
        pose_training: A `PoseTraining` object.
        nwbfile_path: The path where the nwb file is to be written.
    """
    nwb_file_kwargs = nwb_file_kwargs or {}

    try:
        nwbfile = NWBFile(**nwb_file_kwargs)
        nwbfile = append_nwb_data(pose_training, nwbfile, pose_estimation_metadata)
        with NWBHDF5IO(str(nwbfile_path), "w") as io:
            io.write(nwbfile)
    except Exception as e:
        # Handle potential file writing errors
        print(f"Error writing NWB file: {e}")
```
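One caveat with the suggestion above: catching the exception and only printing it hides the failure from callers, who then believe the file was written. An alternative pattern is to clean up any partial output and re-raise. A self-contained sketch with a stand-in writer (not the pynwb API):

```python
from pathlib import Path

def write_with_cleanup(path: str, write_fn) -> None:
    """Attempt a file write; on failure, remove the partial file and re-raise."""
    try:
        write_fn(path)
    except Exception:
        Path(path).unlink(missing_ok=True)  # drop any partially written file
        raise

# Stand-in writer that fails after creating the file, simulating a mid-write error.
def bad_writer(path: str) -> None:
    Path(path).write_text("partial")
    raise IOError("disk full")

try:
    write_with_cleanup("demo.nwb", bad_writer)
except IOError:
    print("write failed; partial file removed:", not Path("demo.nwb").exists())
```

Re-raising preserves the traceback for the caller while still guaranteeing no corrupt `.nwb` file is left on disk.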
Actionable comments posted: 0
Outside diff range, codebase verification and nitpick comments (9)
tests/io/test_nwb.py (9)
Line range hint 45-67: Review of `test_typical_case_append`

This function effectively tests the appending of SLEAP data to NWB files, ensuring that the structure and naming conventions are maintained. It might be beneficial to extract some of the repeated logic (like processing module extraction and naming) into helper functions for better readability and maintainability.

Line range hint 69-91: Review of `test_typical_case_append_with_metadata_propagation`

This test function correctly ensures that metadata is propagated during the data appending process. The use of explicit metadata values for testing is a good practice. However, consider parameterizing these metadata values to facilitate easier updates and expansions of the test cases in the future.

Line range hint 93-106: Review of `test_provenance_writing`

The function effectively tests the propagation of provenance information, ensuring data traceability. The structure of the test is clear, but consider adding more detailed comments to explain the significance of each assertion for future maintainers.

Line range hint 108-130: Review of `test_default_metadata_overwriting`

This test function correctly checks that default metadata values can be overwritten, allowing for customization. The test is well-structured, but consider adding more detailed comments to explain the significance of each assertion for future maintainers.

Line range hint 132-160: Review of `test_complex_case_append`

This function effectively handles the appending of complex case data, ensuring correct structure and naming within the NWB file. The test is comprehensive, but consider extracting some of the logic into helper functions for better readability and maintainability.

Line range hint 162-194: Review of `test_complex_case_append_with_timestamps_metadata`

This test function correctly ensures that timestamps metadata is propagated and used accurately. The test is well-structured, but consider adding more detailed comments to explain the significance of each assertion for future maintainers.

Line range hint 196-203: Review of `test_assertion_with_no_predicted_instance`

This test function correctly handles the scenario where no predicted instances are found, ensuring robust error handling. The use of explicit error matching is a good practice. Consider adding more detailed comments to explain the significance of this test for future maintainers.

Line range hint 205-216: Review of `test_typical_case_write`

This test function ensures that typical case data is correctly written to an NWB file. The test is well-structured, but consider adding more detailed comments to explain the significance of each assertion for future maintainers.

Line range hint 218-235: Review of `test_get_timestamps`

This test function correctly retrieves timestamps from data series, ensuring accurate data handling. The test is well-structured, but consider adding more detailed comments to explain the significance of each assertion for future maintainers.
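The suggestion above to parameterize metadata values can be done with `pytest.mark.parametrize`, or, dependency-free, with a plain table of cases looped over in the test. A sketch of the table-driven style (the `merge_metadata` helper and its keys are hypothetical, standing in for the metadata-overwriting behavior under test):

```python
# Hypothetical behavior under test: user-supplied metadata overrides defaults.
DEFAULTS = {"scorer": "sleap", "source_software": "SLEAP"}

def merge_metadata(overrides):
    merged = dict(DEFAULTS)
    merged.update(overrides)
    return merged

# Parameterized cases as data: (overrides, expected scorer).
cases = [
    ({}, "sleap"),
    ({"scorer": "alice"}, "alice"),
    ({"scorer": "bob", "video_path": "v.mp4"}, "bob"),
]
for overrides, expected_scorer in cases:
    assert merge_metadata(overrides)["scorer"] == expected_scorer
print("all metadata cases passed")
```

Adding a new metadata scenario then means appending one tuple to `cases` rather than writing a new test function.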
Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Files selected for processing (1)
- tests/io/test_nwb.py (2 hunks)
Additional comments not posted (2)
tests/io/test_nwb.py (2)

Lines 9-15: Review of new imports and functions

The imports and new function declarations for `pose_training_to_labels` and `labels_to_pose_training` are correctly placed and follow Python convention. However, ensure these functions are used appropriately across the codebase and that their implementation is aligned with the project's standards.

Lines 33-43: Review of `test_nwb_slp_conversion`

This test function correctly checks the integrity of data through conversions. It's important to ensure that not only the count of instances and labeled frames is maintained, but also that the content of these structures is unchanged. Consider adding assertions for deeper properties of the instances if not already covered elsewhere.
Actionable comments posted: 0
Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Files selected for processing (1)
- tests/io/test_nwb.py (2 hunks)
Files skipped from review as they are similar to previous changes (1)
- tests/io/test_nwb.py
sleap_io/io/nwb.py
Outdated
```python
def append_nwb_training(pose_training: PoseTraining, nwbfile_path: str) -> NWBFile:  # type: ignore[return]
    """Append a PoseTraining object to an existing NWB data file.

    Args:
        pose_training: A PoseTraining object.
        nwbfile_path: The path to the NWB file.

    Returns:
        An in-memory NWB file with the PoseTraining data appended.
    """
    raise NotImplementedError
```
Implement the `append_nwb_training` function.

Currently, the function raises `NotImplementedError`. This needs to be addressed either by implementing the function or by documenting it as a known limitation.

```diff
-    raise NotImplementedError
+    # TODO: Implement or document this function
```
```python
def append_nwb_training(pose_training: PoseTraining, nwbfile_path: str) -> NWBFile:  # type: ignore[return]
    """Append a PoseTraining object to an existing NWB data file.

    Args:
        pose_training: A PoseTraining object.
        nwbfile_path: The path to the NWB file.

    Returns:
        An in-memory NWB file with the PoseTraining data appended.
    """
    # TODO: Implement or document this function
```
Actionable comments posted: 0
Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Files selected for processing (1)
- sleap_io/io/nwb.py (9 hunks)
Files skipped from review as they are similar to previous changes (1)
- sleap_io/io/nwb.py
Actionable comments posted: 3
Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Files selected for processing (2)
- save_test.ipynb (1 hunks)
- sleap_io/io/main.py (3 hunks)
Additional comments not posted (1)
sleap_io/io/main.py (1)

Lines 50-63: Remove redundant check for `as_training`.

The previous review comment suggested removing the redundant check for `as_training is None`, which seems to have been addressed. However, the logic can be further simplified by directly returning based on the `as_training` flag.

```diff
 def load_nwb(filename: str, as_training: Optional[bool]=None) -> Labels:
     """Load an NWB dataset as a SLEAP `Labels` object.

     Args:
         filename: Path to a NWB file (`.nwb`).
         as_training: If `True`, load the dataset as a training dataset.

     Returns:
         The dataset as a `Labels` object.
     """
-    if as_training:
-        return nwb.read_nwb_training(filename)
-    else:
-        return nwb.read_nwb(filename)
+    return nwb.read_nwb_training(filename) if as_training else nwb.read_nwb(filename)
```
sleap_io/io/main.py
Outdated
```diff
 def save_nwb(labels: Labels, filename: str, as_training: bool = None, append: bool = True, **kwargs):
     """Save a SLEAP dataset to NWB format.

     Args:
         labels: A SLEAP `Labels` object (see `load_slp`).
         filename: Path to NWB file to save to. Must end in `.nwb`.
         as_training: If `True`, save the dataset as a training dataset.
         append: If `True` (the default), append to existing NWB file. File will be
             created if it does not exist.

     See also: nwb.write_nwb, nwb.append_nwb
     """
-    if append and Path(filename).exists():
-        nwb.append_nwb(labels, filename)
-    else:
-        nwb.write_nwb(labels, filename)
+    if as_training:
+        pose_training = nwb.labels_to_pose_training(labels, **kwargs)
+        if append and Path(filename).exists():
+            nwb.append_nwb_training(pose_training, filename, **kwargs)
+        else:
+            nwb.write_nwb_training(pose_training, filename, **kwargs)
+
+    else:
+        if append and Path(filename).exists():
+            nwb.append_nwb(labels, filename, **kwargs)
+        else:
+            nwb.write_nwb(labels, filename)
```
Refactor conditional logic in `save_nwb`.

The function `save_nwb` handles saving of NWB datasets with an optional training mode. The logic can be refactored to reduce the depth of conditionals and improve readability.
```diff
 def save_nwb(labels: Labels, filename: str, as_training: bool = None, append: bool = True, **kwargs):
     """Save a SLEAP dataset to NWB format.

     Args:
         labels: A SLEAP `Labels` object (see `load_slp`).
         filename: Path to NWB file to save to. Must end in `.nwb`.
         as_training: If `True`, save the dataset as a training dataset.
         append: If `True` (the default), append to existing NWB file. File will be
             created if it does not exist.

     See also: nwb.write_nwb, nwb.append_nwb
     """
-    if as_training:
-        pose_training = nwb.labels_to_pose_training(labels, **kwargs)
-        if append and Path(filename).exists():
-            nwb.append_nwb_training(pose_training, filename, **kwargs)
-        else:
-            nwb.write_nwb_training(pose_training, filename, **kwargs)
-
-    else:
-        if append and Path(filename).exists():
-            nwb.append_nwb(labels, filename, **kwargs)
-        else:
-            nwb.write_nwb(labels, filename)
+    func = nwb.labels_to_pose_training if as_training else lambda l, **kw: l
+    action = nwb.append_nwb_training if as_training else nwb.append_nwb
+    write = nwb.write_nwb_training if as_training else nwb.write_nwb
+
+    data = func(labels, **kwargs)
+    if append and Path(filename).exists():
+        action(data, filename, **kwargs)
+    else:
+        write(data, filename, **kwargs)
```
Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
def save_nwb(labels: Labels, filename: str, as_training: bool = None, append: bool = True, **kwargs):
    """Save a SLEAP dataset to NWB format.

    Args:
        labels: A SLEAP `Labels` object (see `load_slp`).
        filename: Path to NWB file to save to. Must end in `.nwb`.
        as_training: If `True`, save the dataset as a training dataset.
        append: If `True` (the default), append to existing NWB file. File will be
            created if it does not exist.

    See also: nwb.write_nwb, nwb.append_nwb
    """
    if append and Path(filename).exists():
        nwb.append_nwb(labels, filename)
    else:
        nwb.write_nwb(labels, filename)
    if as_training:
        pose_training = nwb.labels_to_pose_training(labels, **kwargs)
        if append and Path(filename).exists():
            nwb.append_nwb_training(pose_training, filename, **kwargs)
        else:
            nwb.write_nwb_training(pose_training, filename, **kwargs)
    else:
        if append and Path(filename).exists():
            nwb.append_nwb(labels, filename, **kwargs)
        else:
            nwb.write_nwb(labels, filename)

def save_nwb(labels: Labels, filename: str, as_training: bool = None, append: bool = True, **kwargs):
    """Save a SLEAP dataset to NWB format.

    Args:
        labels: A SLEAP `Labels` object (see `load_slp`).
        filename: Path to NWB file to save to. Must end in `.nwb`.
        as_training: If `True`, save the dataset as a training dataset.
        append: If `True` (the default), append to existing NWB file. File will be
            created if it does not exist.

    See also: nwb.write_nwb, nwb.append_nwb
    """
    func = nwb.labels_to_pose_training if as_training else lambda l, **kw: l
    action = nwb.append_nwb_training if as_training else nwb.append_nwb
    write = nwb.write_nwb_training if as_training else nwb.write_nwb

    data = func(labels, **kwargs)
    if append and Path(filename).exists():
        action(data, filename, **kwargs)
    else:
        write(data, filename, **kwargs)
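The dispatch pattern in the suggestion can be sanity-checked in isolation with stubs standing in for the real writers (the stub names and recording list are ours; the real functions live in `sleap_io.io.nwb`):

```python
from pathlib import Path

calls = []  # records which writer would have been invoked

def append_nwb(data, filename, **kwargs):
    """Stub for nwb.append_nwb: record the call instead of writing."""
    calls.append(("append", str(filename)))

def write_nwb(data, filename, **kwargs):
    """Stub for nwb.write_nwb: record the call instead of writing."""
    calls.append(("write", str(filename)))

def save_nwb_dispatch(labels, filename, append=True):
    # Pick the writer once, then there is only a single call site.
    action = append_nwb if (append and Path(filename).exists()) else write_nwb
    action(labels, filename)
```

Selecting the function up front keeps the existence check in one place, which is the readability gain the refactor is after.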
save_test.ipynb
Outdated
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
  {
    "ename": "KeyError",
    "evalue": "'tests/data/slp/minimal_instance.pkg.slp'",
    "output_type": "error",
    "traceback": [
      "KeyError                                  Traceback (most recent call last)",
      "Cell In[3], line 4 -> labels_original.save(\"minimal_instance.pkg.nwb\", format=\"nwb_training\")",
      "File ~/salk/io_fork/sleap_io/model/labels.py:372, in Labels.save",
      "File ~/salk/io_fork/sleap_io/io/main.py:241, in save_file",
      "File ~/salk/io_fork/sleap_io/io/main.py:92, in save_nwb",
      "File ~/salk/io_fork/sleap_io/io/nwb.py:388, in write_nwb",
      "File ~/salk/io_fork/sleap_io/io/nwb.py:471, in append_nwb_data",
      "File ~/mambaforge3/envs/io_dev/lib/python3.12/site-packages/pandas/core/frame.py:4102, in DataFrame.__getitem__",
      "File ~/mambaforge3/envs/io_dev/lib/python3.12/site-packages/pandas/core/indexes/range.py:417, in RangeIndex.get_loc",
      "KeyError: 'tests/data/slp/minimal_instance.pkg.slp'"
    ]
  }
],
"source": [
  "import sleap_io as sio\n",
  "\n",
  "labels_original = sio.load_slp(\"tests/data/slp/minimal_instance.pkg.slp\")\n",
  "labels_original.save(\"minimal_instance.pkg.nwb\", format=\"nwb_training\")\n",
  "labels_loaded = sio.load_nwb(\"minimal_instance.pkg.nwb\")"
Handle file loading and saving errors more gracefully.
The notebook cell attempts to load and save files, but it results in a `KeyError`. This suggests that error handling needs to be improved to provide more informative messages or to ensure that the file paths are correct.
try:
    labels_original = sio.load_slp("tests/data/slp/minimal_instance.pkg.slp")
    labels_original.save("minimal_instance.pkg.nwb", format="nwb_training")
    labels_loaded = sio.load_nwb("minimal_instance.pkg.nwb")
except Exception as e:
    print(f"Error occurred: {e}")
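Beyond wrapping the calls in `try`/`except`, validating the input path up front gives a clearer failure mode than the opaque `KeyError` in the traceback above. A minimal sketch (the helper name and its checks are ours, not part of sleap-io):

```python
from pathlib import Path

def checked_path(filename: str, suffix: str) -> Path:
    """Fail fast with a readable message before handing a path to a loader.

    Hypothetical helper: checks the extension and existence so callers see
    a clear error instead of one raised deep inside the I/O stack.
    """
    path = Path(filename)
    if path.suffix != suffix:
        raise ValueError(f"Expected a '{suffix}' file, got: {filename}")
    if not path.exists():
        raise FileNotFoundError(f"No such file: {filename}")
    return path
```

With this in place, `checked_path("tests/data/slp/minimal_instance.pkg.slp", ".slp")` either returns a usable `Path` or raises immediately with the actual problem.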
save_test.ipynb
Outdated
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
  {
    "name": "stdout",
    "output_type": "stream",
    "text": [
      "minimal_instance.pkg.nwb_images/img_0.png\n"
    ]
  },
  {
    "ename": "ValueError",
    "evalue": "Can't write images with one color channel.",
    "output_type": "error",
    "traceback": [
      "ValueError                                Traceback (most recent call last)",
      "Cell In[2], line 16 -> iio.imwrite(img_path, lf.image)",
      "File ~/mambaforge3/envs/io_dev/lib/python3.12/site-packages/imageio/v3.py:147, in imwrite",
      "File ~/mambaforge3/envs/io_dev/lib/python3.12/site-packages/imageio/plugins/pillow.py:433, in PillowPlugin.write",
      "ValueError: Can't write images with one color channel."
    ]
  }
],
"source": [
  "import imageio.v3 as iio\n",
  "from pathlib import Path\n",
  "import sleap_io as sio\n",
  "\n",
  "save_path = Path(\"minimal_instance.pkg.nwb_images\")\n",
  "try:\n",
  "    save_path.mkdir(parents=True, exist_ok=True)\n",
  "except Exception as e:\n",
  "    print(f\"An error {e} occurred. The directory could not be created.\")\n",
  "img_paths = []\n",
  "\n",
  "labels_original = sio.load_slp(\"tests/data/slp/minimal_instance.pkg.slp\")\n",
  "for i, lf in enumerate(labels_original):\n",
  "    img_path = save_path / f\"img_{i}.png\"\n",
  "    print(img_path)\n",
  "    iio.imwrite(img_path, lf.image)\n",
  "    img_paths.append(img_path)\n",
  "print(img_paths)"
]
Improve error handling for image writing.
This cell attempts to write images, but fails due to an unsupported single color channel. This issue should be caught and handled, or the notebook should ensure that only supported image formats are processed.
try:
    for i, lf in enumerate(labels_original):
        img_path = save_path / f"img_{i}.png"
        print(img_path)
        iio.imwrite(img_path, lf.image)
        img_paths.append(img_path)
except ValueError as e:
    print(f"Failed to write image: {e}")
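Rather than only catching the exception, the root cause can be removed: imageio's Pillow plugin rejects `(H, W, 1)` arrays, but accepts the equivalent 2-D grayscale array. A small conversion helper (the function name is ours) makes the writes succeed:

```python
import numpy as np

def prepare_for_imwrite(image) -> np.ndarray:
    """Drop a trailing singleton channel so Pillow-backed writers accept it.

    imageio raises "Can't write images with one color channel." for
    (H, W, 1) arrays; squeezing to (H, W) yields a writable grayscale image.
    Arrays with 2, 3, or 4 channels (or already 2-D) pass through unchanged.
    """
    image = np.asarray(image)
    if image.ndim == 3 and image.shape[-1] == 1:
        return image[..., 0]  # (H, W, 1) -> (H, W)
    return image
```

In the loop above, `iio.imwrite(img_path, prepare_for_imwrite(lf.image))` would then handle single-channel frames from grayscale videos.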
Actionable comments posted: 5
Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Files selected for processing (3)
- save_test.ipynb (1 hunks)
- sleap_io/io/main.py (3 hunks)
- sleap_io/io/nwb.py (10 hunks)
Files skipped from review as they are similar to previous changes (1)
- save_test.ipynb
Additional context used
Ruff
sleap_io/io/main.py
76-76: Local variable `pose_training` is assigned to but never used. Remove assignment to unused variable `pose_training` (F841)
sleap_io/io/nwb.py
519-519: Local variable `pose_training` is assigned to but never used. Remove assignment to unused variable `pose_training` (F841)
520-520: Local variable `provenance` is assigned to but never used. Remove assignment to unused variable `provenance` (F841)
sleap_io/io/main.py
Outdated
def load_nwb(filename: str, as_training: Optional[bool] = None) -> Labels:
    """Load an NWB dataset as a SLEAP `Labels` object.

    Args:
        filename: Path to a NWB file (`.nwb`).
        as_training: If `True`, load the dataset as a training dataset.
Simplify conditional logic in `load_nwb`.
The function now includes an `as_training` parameter to load datasets as training data. However, the logic can be simplified to avoid redundancy.
-   if as_training is None:
-       return nwb.read_nwb(filename)
-
    if as_training:
        return nwb.read_nwb_training(filename)
    else:
        return nwb.read_nwb(filename)
Committable suggestion was skipped due to low confidence.
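The simplified branch suggested above can be exercised in isolation with stub readers (the stubs are ours; the real readers live in `sleap_io.io.nwb` and return `Labels` objects):

```python
def read_nwb(filename):
    """Stub for nwb.read_nwb (predictions reader)."""
    return ("predictions", filename)

def read_nwb_training(filename):
    """Stub for nwb.read_nwb_training (training-data reader)."""
    return ("training", filename)

def load_nwb(filename, as_training=None):
    # `None` and `False` both mean "load as predictions", so a single
    # truthiness check replaces the redundant three-way branch.
    reader = read_nwb_training if as_training else read_nwb
    return reader(filename)
```

The default-argument case, the explicit `False` case, and the `True` case all fall out of one conditional expression.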
sleap_io/io/main.py
Outdated
def save_nwb(labels: Labels, filename: str, as_training: bool = None, append: bool = True, **kwargs):
    """Save a SLEAP dataset to NWB format.

    Args:
        labels: A SLEAP `Labels` object (see `load_slp`).
        filename: Path to NWB file to save to. Must end in `.nwb`.
        as_training: If `True`, save the dataset as a training dataset.
        append: If `True` (the default), append to existing NWB file. File will be
            created if it does not exist.

    See also: nwb.write_nwb, nwb.append_nwb
    """
    if append and Path(filename).exists():
        nwb.append_nwb(labels, filename)
    else:
        nwb.write_nwb(labels, filename)
    if as_training:
        pose_training = nwb.labels_to_pose_training(labels, **kwargs)
        if append and Path(filename).exists():
            nwb.append_nwb(labels, filename)
        else:
            nwb.write_nwb(labels, filename)

    else:
        if append and Path(filename).exists():
            nwb.append_nwb(labels, filename, **kwargs)
        else:
            nwb.write_nwb(labels, filename)
Refactor `save_nwb` to reduce complexity and address previous comments.
The function has been updated to handle the `as_training` parameter. However, the current implementation is complex and can be simplified. Additionally, address the previous comment about reducing the depth of conditionals.
-   if as_training:
-       pose_training = nwb.labels_to_pose_training(labels, **kwargs)
-       if append and Path(filename).exists():
-           nwb.append_nwb(labels, filename)
-       else:
-           nwb.write_nwb(labels, filename)
-
-   else:
-       if append and Path(filename).exists():
-           nwb.append_nwb(labels, filename, **kwargs)
-       else:
-           nwb.write_nwb(labels, filename)
+   func = nwb.labels_to_pose_training if as_training else lambda l, **kw: l
+   action = nwb.append_nwb_training if as_training else nwb.append_nwb
+   write = nwb.write_nwb_training if as_training else nwb.write_nwb
+
+   data = func(labels, **kwargs)
+   if append and Path(filename).exists():
+       action(data, filename, **kwargs)
+   else:
+       write(data, filename, **kwargs)
Committable suggestion
def save_nwb(labels: Labels, filename: str, as_training: bool = None, append: bool = True, **kwargs):
    """Save a SLEAP dataset to NWB format.

    Args:
        labels: A SLEAP `Labels` object (see `load_slp`).
        filename: Path to NWB file to save to. Must end in `.nwb`.
        as_training: If `True`, save the dataset as a training dataset.
        append: If `True` (the default), append to existing NWB file. File will be
            created if it does not exist.

    See also: nwb.write_nwb, nwb.append_nwb
    """
    if append and Path(filename).exists():
        nwb.append_nwb(labels, filename)
    else:
        nwb.write_nwb(labels, filename)
    if as_training:
        pose_training = nwb.labels_to_pose_training(labels, **kwargs)
        if append and Path(filename).exists():
            nwb.append_nwb(labels, filename)
        else:
            nwb.write_nwb(labels, filename)
    else:
        if append and Path(filename).exists():
            nwb.append_nwb(labels, filename, **kwargs)
        else:
            nwb.write_nwb(labels, filename)

def save_nwb(labels: Labels, filename: str, as_training: bool = None, append: bool = True, **kwargs):
    """Save a SLEAP dataset to NWB format.

    Args:
        labels: A SLEAP `Labels` object (see `load_slp`).
        filename: Path to NWB file to save to. Must end in `.nwb`.
        as_training: If `True`, save the dataset as a training dataset.
        append: If `True` (the default), append to existing NWB file. File will be
            created if it does not exist.

    See also: nwb.write_nwb, nwb.append_nwb
    """
    func = nwb.labels_to_pose_training if as_training else lambda l, **kw: l
    action = nwb.append_nwb_training if as_training else nwb.append_nwb
    write = nwb.write_nwb_training if as_training else nwb.write_nwb

    data = func(labels, **kwargs)
    if append and Path(filename).exists():
        action(data, filename, **kwargs)
    else:
        write(data, filename, **kwargs)
Tools
Ruff
76-76: Local variable `pose_training` is assigned to but never used. Remove assignment to unused variable `pose_training` (F841)
Actionable comments posted: 8
Outside diff range, codebase verification and nitpick comments (8)
sleap_io/io/nwb.py (8)
Line range hint 238-357: Break down `read_nwb` and add error handling. The function is long and performs multiple tasks. Consider breaking it down into smaller functions and adding error handling for invalid inputs.

def read_nwb(path: str) -> Labels:
+   if not isinstance(path, str) or not path.endswith(".nwb"):
+       raise ValueError("Invalid path: Expected a .nwb file.")
Line range hint 359-414: Add error handling and type annotations to `write_nwb`. The function lacks error handling for invalid inputs and type annotations for better readability and maintainability.

def write_nwb(
    labels: Labels,
    nwbfile_path: str,
    nwb_file_kwargs: Optional[dict] = None,
    pose_estimation_metadata: Optional[dict] = None,
    as_training: Optional[bool] = None,
):
+   if not isinstance(labels, Labels):
+       raise ValueError("Invalid input: Expected a Labels object.")
Line range hint 417-491: Break down `append_nwb_data` and add error handling. The function is long and performs multiple tasks. Consider breaking it down into smaller functions and adding error handling for invalid inputs.

def append_nwb_data(
    labels: Labels, nwbfile: NWBFile, pose_estimation_metadata: Optional[dict] = None
) -> NWBFile:
+   if not isinstance(labels, Labels) or not isinstance(nwbfile, NWBFile):
+       raise ValueError("Invalid input: Expected a Labels object and an NWBFile object.")
Line range hint 510-536: Add error handling and type annotations to `append_nwb`. The function lacks error handling for invalid inputs and type annotations for better readability and maintainability.

def append_nwb(
    labels: Labels,
    filename: str,
    pose_estimation_metadata: Optional[dict] = None,
    as_training: Optional[bool] = None,
):
+   if not isinstance(labels, Labels) or not isinstance(filename, str):
+       raise ValueError("Invalid input: Expected a Labels object and a filename string.")
Line range hint 539-552: Add error handling and type annotations to `get_processing_module_for_video`. The function lacks error handling for invalid inputs and type annotations for better readability and maintainability.

def get_processing_module_for_video(
    processing_module_name: str, nwbfile: NWBFile
) -> ProcessingModule:
+   if not isinstance(processing_module_name, str) or not isinstance(nwbfile, NWBFile):
+       raise ValueError("Invalid input: Expected a processing module name string and an NWBFile object.")
Line range hint 555-636: Break down `build_pose_estimation_container_for_track` and add error handling. The function is long and performs multiple tasks. Consider breaking it down into smaller functions and adding error handling for invalid inputs.

def build_pose_estimation_container_for_track(
    labels_data_df: pd.DataFrame,
    labels: Labels,
    track_name: str,
    video: Video,
    pose_estimation_metadata: dict,
) -> PoseEstimation:
+   if not isinstance(labels_data_df, pd.DataFrame) or not isinstance(labels, Labels) or not isinstance(track_name, str) or not isinstance(video, Video):
+       raise ValueError("Invalid input: Expected a DataFrame, Labels object, track name string, and Video object.")
Line range hint 638-686: Add error handling and type annotations to `build_track_pose_estimation_list`. The function lacks error handling for invalid inputs and type annotations for better readability and maintainability.

def build_track_pose_estimation_list(
    track_data_df: pd.DataFrame, timestamps: ArrayLike
) -> List[PoseEstimationSeries]:
+   if not isinstance(track_data_df, pd.DataFrame) or not isinstance(timestamps, np.ndarray):
+       raise ValueError("Invalid input: Expected a DataFrame and an ndarray.")
Line range hint 238-246: Add error handling and type annotations to `get_timestamps`. The function lacks error handling for invalid inputs and type annotations for better readability and maintainability.

def get_timestamps(series: PoseEstimationSeries) -> np.ndarray:
+   if series is None or not hasattr(series, 'timestamps') or not hasattr(series, 'data'):
+       raise ValueError("Invalid series: Missing required data.")
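Beyond the attribute checks, a `get_timestamps` helper also has to handle the two ways NWB time series store timing: explicit timestamps, or a sampling rate plus starting time. A runnable sketch of that fallback (`FakeSeries` is a test double we introduce here; the real class is `ndx_pose.PoseEstimationSeries`):

```python
import numpy as np

class FakeSeries:
    """Minimal stand-in for an ndx-pose PoseEstimationSeries."""
    def __init__(self, data, timestamps=None, rate=None, starting_time=0.0):
        self.data = data
        self.timestamps = timestamps
        self.rate = rate
        self.starting_time = starting_time

def get_timestamps(series) -> np.ndarray:
    """Return per-sample timestamps for a pose series.

    Prefers explicit timestamps; otherwise reconstructs them from the
    sampling rate and starting time, the two timing forms NWB allows.
    """
    if getattr(series, "timestamps", None) is not None:
        return np.asarray(series.timestamps)
    if series.rate is None:
        raise ValueError("Series has neither timestamps nor a sampling rate.")
    n = len(series.data)
    return series.starting_time + np.arange(n) / series.rate
```

Validating both branches keeps downstream frame-index math from silently assuming one storage form.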
Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Files selected for processing (2)
- sleap_io/io/main.py (3 hunks)
- sleap_io/io/nwb.py (11 hunks)
Files skipped from review as they are similar to previous changes (1)
- sleap_io/io/main.py
Additional context used
Ruff
sleap_io/io/nwb.py
504-504: Local variable `pose_training` is assigned to but never used. Remove assignment to unused variable `pose_training` (F841)
505-505: Local variable `provenance` is assigned to but never used. Remove assignment to unused variable `provenance` (F841)
sleap_io/io/nwb.py
Outdated
def append_nwb_training(labels: Labels, nwbfile_path: str) -> NWBFile:  # type: ignore[return]
    """Append a PoseTraining object to an existing NWB data file.

    Args:
        pose_training: A PoseTraining object.
        nwbfile_path: The path to the NWB file.

    Returns:
        An in-memory NWB file with the PoseTraining data appended.
    """
    pose_training = labels_to_pose_training(labels)
    provenance = labels.provenance
    raise NotImplementedError
Implement or document the `append_nwb_training` function.

Currently, the function raises `NotImplementedError`. This needs to be addressed either by implementing the function or by documenting it as a known limitation.

- raise NotImplementedError
+ # TODO: Implement or document this function
sleap_io/io/nwb.py
Outdated
def nwb_skeleton_to_sleap(skeleton: NWBSkeleton) -> SLEAPSkeleton:  # type: ignore[return]
    """Converts an NWB skeleton to a SLEAP skeleton.

    Args:
        skeleton: An NWB skeleton.

    Returns:
        A SLEAP skeleton.
    """
    nodes = [Node(name=node) for node in skeleton.nodes]
    edges = [Edge(source=edge[0], destination=edge[1]) for edge in skeleton.edges]
    return SLEAPSkeleton(
        nodes=nodes,
        edges=edges,
        name=skeleton.name,
    )
Add error handling and type annotations to `nwb_skeleton_to_sleap`.

The function lacks error handling for invalid inputs and type annotations for better readability and maintainability.

  def nwb_skeleton_to_sleap(skeleton: NWBSkeleton) -> SLEAPSkeleton:
+     if skeleton is None or not hasattr(skeleton, 'nodes') or not hasattr(skeleton, 'edges'):
+         raise ValueError("Invalid skeleton: Missing required data.")
sleap_io/io/nwb.py
Outdated
def sleap_pkg_to_nwb(filename: str, labels: Labels, **kwargs):
    """Write a SLEAP package to an NWB file.

    Args:
        filename: The path to the SLEAP package.
        labels: The SLEAP Labels object.
    """
    assert filename.endswith(".pkg.slp")

    path = filename.split(".slp")[0]
    save_path = Path(path + ".nwb_images")
    img_paths = []
    for i, labeled_frame in enumerate(labels.labeled_frames):
        img_path = save_path / f"frame_{i}.png"
        imwrite(img_path, labeled_frame.image)
        img_paths.append(img_path)
Add error handling and type annotations to `sleap_pkg_to_nwb`.

The function lacks error handling for invalid inputs and type annotations for better readability and maintainability.

  def sleap_pkg_to_nwb(filename: str, labels: Labels, **kwargs):
+     if not filename.endswith(".pkg.slp"):
+         raise ValueError("Invalid filename: Expected a .pkg.slp file.")
Committable suggestion was skipped due to low confidence.
sleap_io/io/nwb.py
Outdated
def videos_to_source_videos(videos: List[Video]) -> SourceVideos:  # type: ignore[return]
    """Converts a list of SLEAP Videos to NWB SourceVideos.

    Args:
        videos: A list of SLEAP Videos.

    Returns:
        An NWB SourceVideos object.
    """
    source_videos = []
    for i, video in enumerate(videos):
        image_series = ImageSeries(
            name=f"video_{i}",
            description="Video file",
            unit="NA",
            format="external",
            external_file=[video.filename],
            dimension=[video.backend.img_shape[0], video.backend.img_shape[1]],
            starting_frame=[0],
            rate=30.0,  # TODO - change to `video.backend.fps` when available
        )
        source_videos.append(image_series)
    return SourceVideos(image_series=source_videos)
Add error handling and type annotations to `videos_to_source_videos`.

The function lacks error handling for invalid inputs and type annotations for better readability and maintainability.

  def videos_to_source_videos(videos: List[Video]) -> SourceVideos:
+     if not isinstance(videos, list) or not all(isinstance(video, Video) for video in videos):
+         raise ValueError("Invalid input: Expected a list of Video objects.")
Actionable comments posted: 10
Outside diff range, codebase verification and nitpick comments (1)
sleap_io/io/nwb.py (1)
Line range hint 641-678: Add error handling and type annotations to `build_track_pose_estimation_list`.

The function lacks error handling for invalid inputs and type annotations for better readability and maintainability.

  def build_track_pose_estimation_list(
      track_data_df: pd.DataFrame, timestamps: ArrayLike  # type: ignore[return]
  ) -> List[PoseEstimationSeries]:
+     if not isinstance(track_data_df, pd.DataFrame):
+         raise ValueError("Invalid input: Expected a pandas DataFrame.")
Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Files selected for processing (2)
- sleap_io/io/main.py (3 hunks)
- sleap_io/io/nwb.py (11 hunks)
Additional context used
Ruff
sleap_io/io/nwb.py
505-505: Local variable `pose_training` is assigned to but never used. Remove assignment to unused variable `pose_training` (F841)

506-506: Local variable `provenance` is assigned to but never used. Remove assignment to unused variable `provenance` (F841)
sleap_io/io/main.py
Outdated
def save_nwb(
    labels: Labels,
    filename: str,
    as_training: bool = None,
    append: bool = True,
    **kwargs,
):
    """Save a SLEAP dataset to NWB format.

    Args:
        labels: A SLEAP `Labels` object (see `load_slp`).
        filename: Path to NWB file to save to. Must end in `.nwb`.
        as_training: If `True`, save the dataset as a training dataset.
        append: If `True` (the default), append to existing NWB file. File will be
            created if it does not exist.

    See also: nwb.write_nwb, nwb.append_nwb
    """
    if as_training:
        if append and Path(filename).exists():
            nwb.append_nwb_training(labels, filename)
        else:
            nwb.write_nwb(labels, filename, None, None, True)
    else:
        if append and Path(filename).exists():
            nwb.append_nwb(labels, filename, **kwargs)
        else:
            nwb.write_nwb(labels, filename)
Refactor `save_nwb` to reduce complexity and address previous comments.

The function has been updated to handle the `as_training` parameter. However, the current implementation is complex and can be simplified. Additionally, address the previous comment about reducing the depth of conditionals.

- if as_training:
-     pose_training = nwb.labels_to_pose_training(labels, **kwargs)
-     if append and Path(filename).exists():
-         nwb.append_nwb_training(pose_training, filename, **kwargs)
-     else:
-         nwb.write_nwb_training(pose_training, filename, **kwargs)
- else:
-     if append and Path(filename).exists():
-         nwb.append_nwb(labels, filename, **kwargs)
-     else:
-         nwb.write_nwb(labels, filename)
+ func = nwb.labels_to_pose_training if as_training else lambda l, **kw: l
+ action = nwb.append_nwb_training if as_training else nwb.append_nwb
+ write = nwb.write_nwb_training if as_training else nwb.write_nwb
+
+ data = func(labels, **kwargs)
+ if append and Path(filename).exists():
+     action(data, filename, **kwargs)
+ else:
+     write(data, filename, **kwargs)
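The dispatch pattern proposed above can be exercised standalone. A minimal sketch using stub writer functions in place of the real `nwb` module helpers (the names `append_plain`, `write_plain`, and so on are illustrative, not sleap-io API):

```python
from pathlib import Path

# Stub writers standing in for nwb.append_nwb / nwb.write_nwb and their
# training-variant counterparts; they just record which path was taken.
calls = []

def append_plain(data, filename):
    calls.append(("append_plain", filename))

def write_plain(data, filename):
    calls.append(("write_plain", filename))

def append_training(data, filename):
    calls.append(("append_training", filename))

def write_training(data, filename):
    calls.append(("write_training", filename))

def save(labels, filename, as_training=False, append=True):
    # Pick the conversion and writer functions once, so the only remaining
    # branch is whether we append to an existing file or write a new one.
    convert = (lambda l: ("training", l)) if as_training else (lambda l: l)
    action = append_training if as_training else append_plain
    write = write_training if as_training else write_plain

    data = convert(labels)
    if append and Path(filename).exists():
        action(data, filename)
    else:
        write(data, filename)

save({"frames": []}, "does_not_exist.nwb", as_training=True)
print(calls[-1][0])  # the file does not exist, so the write path is taken
```

Selecting the functions up front keeps the branching depth at one level regardless of how many training/prediction variants are added later.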
sleap_io/io/nwb.py
Outdated
def videos_to_source_videos(videos: List[Video]) -> SourceVideos:  # type: ignore[return]
    """Converts a list of SLEAP Videos to NWB SourceVideos.

    Args:
        videos: A list of SLEAP Videos.

    Returns:
        An NWB SourceVideos object.
    """
    source_videos = []
    for i, video in enumerate(videos):
        image_series = ImageSeries(
            name=f"video_{i}",
            description="Video file",
            unit="NA",
            format="external",
            external_file=[video.filename],
            dimension=[video.backend.img_shape[0], video.backend.img_shape[1]],
            starting_frame=[0],
            rate=30.0,  # TODO - change to `video.backend.fps` when available
        )
        source_videos.append(image_series)
    return SourceVideos(image_series=source_videos)
Add error handling and type annotations to `videos_to_source_videos`.

The function lacks error handling for invalid inputs and type annotations for better readability and maintainability.

  def videos_to_source_videos(videos: List[Video]) -> SourceVideos:
+     if not isinstance(videos, list) or not all(isinstance(video, Video) for video in videos):
+         raise ValueError("Invalid input: Expected a list of Video objects.")
Committable suggestion was skipped due to low confidence.
sleap_io/io/nwb.py
Outdated
def labels_to_pose_training(labels: Labels, **kwargs) -> PoseTraining:  # type: ignore[return]
    """Creates an NWB PoseTraining object from a Labels object.

    Args:
        labels: A Labels object.
        filename: The filename of the source video.

    Returns:
        A PoseTraining object.
    """
    training_frame_list = []
    for i, labeled_frame in enumerate(labels.labeled_frames):
        training_frame_name = name_generator("training_frame")
        training_frame_annotator = f"{training_frame_name}{i}"
        skeleton_instances_list = []
        for instance in labeled_frame.instances:
            if isinstance(instance, PredictedInstance):
                continue
            skeleton_instance = instance_to_skeleton_instance(instance)
            skeleton_instances_list.append(skeleton_instance)

        training_frame_skeleton_instances = SkeletonInstances(
            skeleton_instances=skeleton_instances_list
        )
        training_frame_video = labeled_frame.video
        training_frame_video_index = labeled_frame.frame_idx
        training_frame = TrainingFrame(
            name=training_frame_name,
            annotator=training_frame_annotator,
            skeleton_instances=training_frame_skeleton_instances,
            source_video=ImageSeries(
                name=training_frame_name,
                description=training_frame_annotator,
                unit="NA",
                format="external",
                external_file=[training_frame_video.filename],
                dimension=[
                    training_frame_video.shape[1],
                    training_frame_video.shape[2],
                ],
                starting_frame=[0],
                rate=30.0,
            ),
            source_video_frame_index=training_frame_video_index,
        )
        training_frame_list.append(training_frame)

    training_frames = TrainingFrames(training_frames=training_frame_list)
    pose_training = PoseTraining(
        training_frames=training_frames,
        source_videos=videos_to_source_videos(labels.videos),
    )
    return pose_training
Break down `labels_to_pose_training` and add error handling.

The function is long and performs multiple tasks. Consider breaking it down into smaller functions and adding error handling for invalid inputs.

  def labels_to_pose_training(labels: Labels, **kwargs) -> PoseTraining:
+     if labels is None:
+         raise ValueError("Labels object cannot be None.")
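One way to read the "break it down" suggestion, sketched with plain dictionaries in place of the SLEAP/NWB objects (the helper names `build_frame_record` and `build_training_records` are illustrative, not sleap-io API):

```python
def build_frame_record(index, frame):
    # One frame's worth of work, pulled out of the monolithic loop:
    # name the frame and keep only the user-labeled (non-predicted) instances.
    name = f"training_frame_{index}"
    instances = [inst for inst in frame["instances"] if not inst.get("predicted")]
    return {"name": name, "instances": instances, "frame_idx": frame["frame_idx"]}

def build_training_records(labels):
    # Validate once at the entry point, as the review suggests.
    if labels is None:
        raise ValueError("Labels object cannot be None.")
    return [build_frame_record(i, f) for i, f in enumerate(labels["frames"])]

labels = {"frames": [
    {"frame_idx": 0, "instances": [{"predicted": False}, {"predicted": True}]},
    {"frame_idx": 5, "instances": [{"predicted": False}]},
]}
records = build_training_records(labels)
print([r["frame_idx"] for r in records])  # [0, 5]
```

Extracting the per-frame step keeps the outer function to validation plus assembly, which also makes the per-frame logic independently testable.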
sleap_io/io/nwb.py
Outdated
def sleap_pkg_to_nwb(filename: str, labels: Labels, **kwargs):
    """Write a SLEAP package to an NWB file.

    Args:
        filename: The path to the SLEAP package.
        labels: The SLEAP Labels object.
    """
    assert filename.endswith(".pkg.slp")

    path = filename.split(".slp")[0]
    save_path = Path(path + ".nwb_images")
    img_paths = []
    for i, labeled_frame in enumerate(labels.labeled_frames):
        img_path = save_path / f"frame_{i}.png"
        imwrite(img_path, labeled_frame.image)
        img_paths.append(img_path)
Add error handling and type annotations to `sleap_pkg_to_nwb`.

The function lacks error handling for invalid inputs and type annotations for better readability and maintainability.

  def sleap_pkg_to_nwb(filename: str, labels: Labels, **kwargs):
+     if not filename.endswith(".pkg.slp"):
+         raise ValueError("Invalid filename: Expected a .pkg.slp file.")
Committable suggestion was skipped due to low confidence.
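As a side note on the suggested check: `str.endswith` handles the compound extension directly, but `pathlib` can make the intent more explicit. A small standalone sketch (the helper name `require_pkg_slp` is illustrative, not sleap-io API):

```python
from pathlib import Path

def require_pkg_slp(filename: str) -> Path:
    # A .pkg.slp file carries a compound extension, so check the last two
    # suffixes rather than Path.suffix (which would only see ".slp").
    path = Path(filename)
    if path.suffixes[-2:] != [".pkg", ".slp"]:
        raise ValueError(f"Invalid filename: Expected a .pkg.slp file, got {filename!r}")
    return path

print(require_pkg_slp("labels.pkg.slp"))  # accepted; a bare .slp would raise
```

Raising `ValueError` instead of using `assert` also keeps the check active when Python runs with optimizations (`-O`), which strips assertions.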
sleap_io/io/nwb.py
Outdated
def append_nwb_training(
    labels: Labels, nwbfile_path: str, pose_estimation_metadata: Optional[dict]
) -> NWBFile:
    """Append a PoseTraining object to an existing NWB data file.

    Args:
        pose_training: A PoseTraining object.
        nwbfile_path: The path to the NWB file.

    Returns:
        An in-memory NWB file with the PoseTraining data appended.
    """
    pose_training = labels_to_pose_training(labels)
    provenance = labels.provenance
    raise NotImplementedError
Implement or document the `append_nwb_training` function.

Currently, the function raises `NotImplementedError`. This needs to be addressed either by implementing the function or by documenting it as a known limitation.

- raise NotImplementedError
+ # TODO: Implement or document this function
Actionable comments posted: 7
Outside diff range, codebase verification and nitpick comments (3)
sleap_io/io/nwb.py (3)
Line range hint 358-415: Review the function `write_nwb`.

This function writes data from a `PoseTraining` object to an NWB file. The function setup and the handling of optional metadata are done correctly. The use of `append_nwb_data` to integrate the training data into the NWB file is appropriate.

However, the function lacks error handling for potential issues during the file writing process. Consider adding error handling mechanisms to improve robustness.

+ try:
      nwbfile = NWBFile(**nwb_file_kwargs)
      nwbfile = append_nwb_data(pose_training, nwbfile, pose_estimation_metadata)
      with NWBHDF5IO(str(nwbfile_path), "w") as io:
          io.write(nwbfile)
+ except Exception as e:
+     # Handle potential file writing errors
+     print(f"Error writing NWB file: {e}")
Line range hint 522-551: Add error handling to `append_nwb`.

The function lacks error handling for potential issues during the file writing process. Consider adding error handling mechanisms to improve robustness.

+ try:
      if as_training:
          with NWBHDF5IO(filename, mode="a", load_namespaces=True) as io:
              nwb_file = io.read()
              nwb_file = append_nwb_training(
                  labels, nwb_file, pose_estimation_metadata=pose_estimation_metadata
              )
              io.write(nwb_file)
      else:
          with NWBHDF5IO(filename, mode="a", load_namespaces=True) as io:
              nwb_file = io.read()
              nwb_file = append_nwb_data(
                  labels, nwb_file, pose_estimation_metadata=pose_estimation_metadata
              )
              io.write(nwb_file)
+ except Exception as e:
+     # Handle potential file writing errors
+     print(f"Error appending to NWB file: {e}")
Line range hint 652-694: Add error handling and type annotations to `build_track_pose_estimation_list`.

The function lacks error handling for invalid inputs and type annotations for better readability and maintainability.

  def build_track_pose_estimation_list(
      track_data_df: pd.DataFrame, timestamps: ArrayLike  # type: ignore[return]
  ) -> List[PoseEstimationSeries]:
+     if track_data_df is None or timestamps is None:
+         raise ValueError("Invalid input: track_data_df and timestamps cannot be None.")
Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Files selected for processing (1)
- sleap_io/io/nwb.py (11 hunks)
Additional context used
Ruff
sleap_io/io/nwb.py
511-511: Local variable `pose_training` is assigned to but never used. Remove assignment to unused variable `pose_training` (F841)

518-518: Local variable `labels_data_df` is assigned to but never used. Remove assignment to unused variable `labels_data_df` (F841)
Additional comments not posted (1)
sleap_io/io/nwb.py (1)
49-72: Incomplete functionality in `pose_training_to_labels`.

The function contains a TODO comment indicating incomplete functionality for determining `frame_idx` from `training_frame`.

- frame_idx = training_frame  # TODO
+ frame_idx = training_frame.frame_index  # Assuming `frame_index` is an attribute of `TrainingFrame`

Do you want me to assist in completing this functionality or open a GitHub issue to track this task?
sleap_io/io/nwb.py
Outdated
def append_nwb_training(
    labels: Labels, nwbfile_path: str, pose_estimation_metadata: Optional[dict]
) -> NWBFile:
    """Append a PoseTraining object to an existing NWB data file.

    Args:
        pose_training: A PoseTraining object.
        nwbfile_path: The path to the NWB file.

    Returns:
        An in-memory NWB file with the PoseTraining data appended.
    """
    pose_estimation_metadata = pose_estimation_metadata or dict()
    provenance = labels.provenance
    default_metadata = dict(scorer=str(provenance))
    sleap_version = provenance.get("sleap_version", None)
    default_metadata["source_software_version"] = sleap_version
    pose_training = labels_to_pose_training(labels)

    for lf in labels.labeled_frames:
        if lf.has_predicted_instances:
            labels_data_df = convert_predictions_to_dataframe(labels)
            break
    else:
        labels_data_df = pd.DataFrame()
    raise NotImplementedError
Implement or document the `append_nwb_training` function.

Currently, the function raises `NotImplementedError`. This needs to be addressed either by implementing the function or by documenting it as a known limitation.

- raise NotImplementedError
+ # TODO: Implement or document this function
Tools
Ruff
511-511: Local variable pose_training is assigned to but never used. Remove assignment to unused variable pose_training (F841)

518-518: Local variable labels_data_df is assigned to but never used. Remove assignment to unused variable labels_data_df (F841)
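For readers unfamiliar with the for/else construct used in the code above: the else branch runs only when the loop finishes without hitting break, which is what makes it usable as a "no match found" default. A minimal standalone sketch (the frame data here is made up for illustration):

```python
# for/else: the else branch executes only if the loop was NOT
# exited via break -- here it supplies the "no match" default.
frames = [{"predicted": False}, {"predicted": True}]

for frame in frames:
    if frame["predicted"]:
        result = "found predictions"
        break
else:
    result = "no predictions"

print(result)  # -> found predictions
```

With an empty list, or one where no frame has predictions, the loop never breaks and the else branch assigns the default instead.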
Actionable comments posted: 7
Outside diff range, codebase verification and nitpick comments (6)
sleap_io/io/nwb.py (6)
Line range hint 361-414: Add error handling and type annotations to write_nwb.

The function lacks error handling for invalid inputs and type annotations for better readability and maintainability.

def write_nwb(
    labels: Labels,
    nwbfile_path: str,
    nwb_file_kwargs: Optional[dict] = None,
    pose_estimation_metadata: Optional[dict] = None,
    as_training: Optional[bool] = None,
):
+   if not isinstance(labels, Labels):
+       raise ValueError("Invalid input: Expected a Labels object.")

Line range hint 446-494: Add error handling and type annotations to append_nwb_data.

The function lacks error handling for invalid inputs and type annotations for better readability and maintainability.

def append_nwb_data(
    labels: Labels, nwbfile: NWBFile, pose_estimation_metadata: Optional[dict] = None
) -> NWBFile:
+   if not isinstance(labels, Labels) or not isinstance(nwbfile, NWBFile):
+       raise ValueError("Invalid input: Expected Labels and NWBFile objects.")

Line range hint 543-571: Add error handling and type annotations to append_nwb.

The function lacks error handling for invalid inputs and type annotations for better readability and maintainability.

def append_nwb(
    labels: Labels,
    filename: str,
    pose_estimation_metadata: Optional[dict] = None,
    as_training: Optional[bool] = None,
):
+   if not isinstance(labels, Labels) or not isinstance(filename, str):
+       raise ValueError("Invalid input: Expected Labels object and filename string.")

Line range hint 573-591: Add type annotations to get_processing_module_for_video.

The function lacks type annotations for better readability and maintainability.

def get_processing_module_for_video(
    processing_module_name: str, nwbfile: NWBFile
) -> ProcessingModule:

Line range hint 593-671: Add error handling and type annotations to build_pose_estimation_container_for_track.

The function lacks error handling for invalid inputs and type annotations for better readability and maintainability.

def build_pose_estimation_container_for_track(
    labels_data_df: pd.DataFrame,
    labels: Labels,
    track_name: str,
    video: Video,
    pose_estimation_metadata: dict,
) -> PoseEstimation:
+   if not isinstance(labels_data_df, pd.DataFrame) or not isinstance(labels, Labels) or not isinstance(track_name, str) or not isinstance(video, Video) or not isinstance(pose_estimation_metadata, dict):
+       raise ValueError("Invalid input: Expected appropriate data types for inputs.")

Line range hint 673-721: Add error handling and type annotations to build_track_pose_estimation_list.

The function lacks error handling for invalid inputs and type annotations for better readability and maintainability.

def build_track_pose_estimation_list(
    track_data_df: pd.DataFrame, timestamps: ArrayLike
) -> List[PoseEstimationSeries]:
+   if not isinstance(track_data_df, pd.DataFrame) or not isinstance(timestamps, np.ndarray):
+       raise ValueError("Invalid input: Expected a pandas DataFrame and a numpy array.")
Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Files selected for processing (1)
- sleap_io/io/nwb.py (11 hunks)
Additional context used
Ruff
sleap_io/io/nwb.py
529-529: Local variable camera is assigned to but never used. Remove assignment to unused variable camera (F841)

538-538: Local variable labels_data_df is assigned to but never used. Remove assignment to unused variable labels_data_df (F841)
Additional comments not posted (1)
sleap_io/io/nwb.py (1)
51-74: Add error handling and type annotations to pose_training_to_labels.

The function lacks error handling for invalid inputs and type annotations for better readability and maintainability.

def pose_training_to_labels(pose_training: PoseTraining) -> Labels:
+   if not isinstance(pose_training, PoseTraining):
+       raise ValueError("Invalid input: Expected a PoseTraining object.")

Likely invalid or redundant comment.
sleap_io/io/nwb.py
Outdated
def append_nwb_training(
    labels: Labels, nwbfile: NWBFile, pose_estimation_metadata: Optional[dict]
) -> NWBFile:
    """Append training data from a Labels object to an in-memory NWB file.

    Args:
        labels: A general labels object.
        nwbfile: An in-memory NWB file.
        pose_estimation_metadata: Metadata for pose estimation.

    Returns:
        An in-memory NWB file with the PoseTraining data appended.
    """
    pose_estimation_metadata = pose_estimation_metadata or dict()
    provenance = labels.provenance
    default_metadata = dict(scorer=str(provenance))
    sleap_version = provenance.get("sleap_version", None)
    default_metadata["source_software_version"] = sleap_version

    subject = Subject(subject_id="No specified id", species="No specified species")
    nwbfile.subject = subject
    pose_training = labels_to_pose_training(labels)

    behavior_pm = nwbfile.create_processing_module(
        name="behavior",
        description="Behavioral data",
    )
    behavior_pm.add(pose_training)

    skeletons_list = [slp_skeleton_to_nwb(skeleton) for skeleton in labels.skeletons]
    skeletons = Skeletons(skeletons=skeletons_list)
    behavior_pm.add(skeletons)

    camera = nwbfile.create_device(
        name="camera",
        description="Camera used to record the video",
        manufacturer="No specified manufacturer",
    )

    for lf in labels.labeled_frames:
        if lf.has_predicted_instances:
            labels_data_df = convert_predictions_to_dataframe(labels)
            break
    else:
        labels_data_df = pd.DataFrame()
    return nwbfile
Add error handling and type annotations to append_nwb_training.

The function lacks error handling for invalid inputs and type annotations for better readability and maintainability.
def append_nwb_training(
    labels: Labels, nwbfile: NWBFile, pose_estimation_metadata: Optional[dict]
) -> NWBFile:
+   if not isinstance(labels, Labels) or not isinstance(nwbfile, NWBFile):
+       raise ValueError("Invalid input: Expected Labels and NWBFile objects.")
def append_nwb_training(
    labels: Labels, nwbfile: NWBFile, pose_estimation_metadata: Optional[dict]
) -> NWBFile:
    if not isinstance(labels, Labels) or not isinstance(nwbfile, NWBFile):
        raise ValueError("Invalid input: Expected Labels and NWBFile objects.")
    """Append training data from a Labels object to an in-memory NWB file.

    Args:
        labels: A general labels object.
        nwbfile: An in-memory NWB file.
        pose_estimation_metadata: Metadata for pose estimation.

    Returns:
        An in-memory NWB file with the PoseTraining data appended.
    """
    pose_estimation_metadata = pose_estimation_metadata or dict()
    provenance = labels.provenance
    default_metadata = dict(scorer=str(provenance))
    sleap_version = provenance.get("sleap_version", None)
    default_metadata["source_software_version"] = sleap_version

    subject = Subject(subject_id="No specified id", species="No specified species")
    nwbfile.subject = subject
    pose_training = labels_to_pose_training(labels)

    behavior_pm = nwbfile.create_processing_module(
        name="behavior",
        description="Behavioral data",
    )
    behavior_pm.add(pose_training)

    skeletons_list = [slp_skeleton_to_nwb(skeleton) for skeleton in labels.skeletons]
    skeletons = Skeletons(skeletons=skeletons_list)
    behavior_pm.add(skeletons)

    camera = nwbfile.create_device(
        name="camera",
        description="Camera used to record the video",
        manufacturer="No specified manufacturer",
    )

    for lf in labels.labeled_frames:
        if lf.has_predicted_instances:
            labels_data_df = convert_predictions_to_dataframe(labels)
            break
    else:
        labels_data_df = pd.DataFrame()
    return nwbfile
Tools
Ruff
529-529: Local variable camera is assigned to but never used. Remove assignment to unused variable camera (F841)

538-538: Local variable labels_data_df is assigned to but never used. Remove assignment to unused variable labels_data_df (F841)
Actionable comments posted: 1
Outside diff range, codebase verification and nitpick comments (5)
sleap_io/io/nwb.py (5)
Line range hint 361-414: Add error handling and type annotations to write_nwb.

The function lacks error handling for invalid inputs and type annotations for better readability and maintainability.

def write_nwb(
    labels: Labels,
    nwbfile_path: str,
    nwb_file_kwargs: Optional[dict] = None,
    pose_estimation_metadata: Optional[dict] = None,
    as_training: Optional[bool] = None,
):
+   if not isinstance(labels, Labels) or not isinstance(nwbfile_path, str):
+       raise ValueError("Invalid input: Expected Labels object and a string path.")

Line range hint 417-493: Add error handling and type annotations to append_nwb_data.

The function lacks error handling for invalid inputs and type annotations for better readability and maintainability.

def append_nwb_data(
    labels: Labels, nwbfile: NWBFile, pose_estimation_metadata: Optional[dict] = None
) -> NWBFile:
+   if not isinstance(labels, Labels) or not isinstance(nwbfile, NWBFile):
+       raise ValueError("Invalid input: Expected Labels and NWBFile objects.")

Line range hint 576-604: Add error handling and type annotations to append_nwb.

The function lacks error handling for invalid inputs and type annotations for better readability and maintainability.

def append_nwb(
    labels: Labels,
    filename: str,
    pose_estimation_metadata: Optional[dict] = None,
    as_training: Optional[bool] = None,
):
+   if not isinstance(labels, Labels) or not isinstance(filename, str):
+       raise ValueError("Invalid input: Expected Labels object and a string filename.")

Line range hint 606-690: Add error handling and type annotations to build_pose_estimation_container_for_track.

The function lacks error handling for invalid inputs and type annotations for better readability and maintainability.

def build_pose_estimation_container_for_track(
    labels_data_df: pd.DataFrame,
    labels: Labels,
    track_name: str,
    video: Video,
    pose_estimation_metadata: dict,
) -> PoseEstimation:
+   if not isinstance(labels_data_df, pd.DataFrame) or not isinstance(labels, Labels) or not isinstance(track_name, str) or not isinstance(video, Video):
+       raise ValueError("Invalid input: Expected DataFrame, Labels object, string track name, and Video object.")

Line range hint 706-749: Add error handling and type annotations to build_track_pose_estimation_list.

The function lacks error handling for invalid inputs and type annotations for better readability and maintainability.

def build_track_pose_estimation_list(
    track_data_df: pd.DataFrame, timestamps: ArrayLike
) -> List[PoseEstimationSeries]:
+   if not isinstance(track_data_df, pd.DataFrame) or not isinstance(timestamps, np.ndarray):
+       raise ValueError("Invalid input: Expected DataFrame and ndarray.")
Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Files selected for processing (1)
- sleap_io/io/nwb.py (11 hunks)
Additional context used
Ruff
sleap_io/io/nwb.py
551-551: Local variable pose_estimation is assigned to but never used. Remove assignment to unused variable pose_estimation (F841)

571-571: Local variable labels_data_df is assigned to but never used. Remove assignment to unused variable labels_data_df (F841)
Additional comments not posted (7)
sleap_io/io/nwb.py (7)
77-93: Add error handling and type annotations to nwb_skeleton_to_sleap.

The function lacks error handling for invalid inputs and type annotations for better readability and maintainability.

def nwb_skeleton_to_sleap(skeleton: NWBSkeleton) -> SLEAPSkeleton:
+   if skeleton is None or not hasattr(skeleton, 'nodes') or not hasattr(skeleton, 'edges'):
+       raise ValueError("Invalid skeleton: Missing required data.")

95-147: Break down labels_to_pose_training and add error handling.

The function is long and performs multiple tasks. Consider breaking it down into smaller functions and adding error handling for invalid inputs.

def labels_to_pose_training(labels: Labels, **kwargs) -> PoseTraining:
+   if labels is None:
+       raise ValueError("Labels object cannot be None.")

150-172: Add error handling and type annotations to slp_skeleton_to_nwb.

The function lacks error handling for invalid skeleton inputs and type annotations for better readability and maintainability.

def slp_skeleton_to_nwb(skeleton: SLEAPSkeleton) -> NWBSkeleton:
+   if skeleton is None or not hasattr(skeleton, 'nodes') or not hasattr(skeleton, 'edges'):
+       raise ValueError("Invalid skeleton: Missing required data.")

175-194: Add error handling and type annotations to instance_to_skeleton_instance.

The function lacks error handling for invalid inputs and type annotations for better readability and maintainability.

def instance_to_skeleton_instance(instance: Instance) -> SkeletonInstance:
+   if instance is None or not hasattr(instance, 'skeleton') or not hasattr(instance, 'points'):
+       raise ValueError("Invalid instance: Missing required data.")

197-219: Add error handling and type annotations to videos_to_source_videos.

The function lacks error handling for invalid inputs and type annotations for better readability and maintainability.

def videos_to_source_videos(videos: List[Video]) -> SourceVideos:
+   if not isinstance(videos, list) or not all(isinstance(video, Video) for video in videos):
+       raise ValueError("Invalid input: Expected a list of Video objects.")

222-238: Add error handling and type annotations to sleap_pkg_to_nwb.

The function lacks error handling for invalid inputs and type annotations for better readability and maintainability.

def sleap_pkg_to_nwb(filename: str, labels: Labels, **kwargs):
+   if not filename.endswith(".pkg.slp"):
+       raise ValueError("Invalid filename: Expected a .pkg.slp file.")

496-572: Add error handling and type annotations to append_nwb_training.

The function lacks error handling for invalid inputs and type annotations for better readability and maintainability.

def append_nwb_training(
    labels: Labels, nwbfile: NWBFile, pose_estimation_metadata: Optional[dict]
) -> NWBFile:
+   if not isinstance(labels, Labels) or not isinstance(nwbfile, NWBFile):
+       raise ValueError("Invalid input: Expected Labels and NWBFile objects.")

Tools
Ruff
551-551: Local variable pose_estimation is assigned to but never used. Remove assignment to unused variable pose_estimation (F841)

571-571: Local variable labels_data_df is assigned to but never used. Remove assignment to unused variable labels_data_df (F841)
Actionable comments posted: 0
Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Files selected for processing (1)
- sleap_io/io/nwb.py (15 hunks)
Additional context used
Ruff
sleap_io/io/nwb.py
17-20: Use contextlib.suppress(ImportError) instead of try-except-pass. Replace with contextlib.suppress(ImportError) (SIM105)
Additional comments not posted (13)
sleap_io/io/nwb.py (13)
61-69: Add error handling for invalid inputs. The function lacks error handling for invalid pose_training inputs, which can improve robustness and maintainability.

75-78: Handle multiple external files. The current implementation only supports single-file external videos. Consider implementing logic to handle multiple external files.

82-92: Optimize instance processing. The loop processing the instances can be optimized by using list comprehensions.

99-115: Add error handling for invalid inputs. The function lacks error handling for invalid skeleton inputs, which can improve robustness and maintainability.

117-166: Add error handling for invalid inputs. The function lacks error handling for invalid inputs, which can improve robustness and maintainability.

168-194: Add error handling for invalid inputs. The function lacks error handling for invalid skeleton inputs, which can improve robustness and maintainability.

197-220: Add error handling for invalid inputs. The function lacks error handling for invalid inputs, which can improve robustness and maintainability.

215-216: Complete the TODO comment for instance tracking. The function contains a TODO comment indicating incomplete functionality for tracking the number of instances.

223-246: Add error handling for invalid inputs. The function lacks error handling for invalid videos inputs, which can improve robustness and maintainability.

242-242: Complete the TODO comment for video.backend.fps. The function contains a TODO comment indicating the need to change the rate to video.backend.fps when available.

314-314: Complete the TODO comment for video.backend.fps. The function contains a TODO comment indicating the need to change the rate to video.backend.fps when available.
labels
,nwbfile
,pose_estimation_metadata
,frame_inds
, andframe_path
inputs, which can improve robustness and maintainability.
857-857
: Add error handling for invalid inputs.The function lacks error handling for invalid inputs, which can improve robustness and maintainability.
Actionable comments posted: 6
Outside diff range, codebase verification and nitpick comments (5)
tests/io/test_main.py (1)
40-42: Add assertions to verify the saved NWB file.

The test_nwb_training function currently lacks assertions to verify the correctness of the operation. Consider adding checks to ensure that the labels are saved and loaded correctly.

loaded_labels = load_nwb(tmp_path / "test_nwb.nwb")
assert isinstance(loaded_labels, Labels)
assert len(loaded_labels) == len(labels)

sleap_io/io/nwb.py (4)
17-20
: Usecontextlib.suppress
for import suppression.Replace the
try
-except
-pass
block withcontextlib.suppress
to improve readability.-try: - import cv2 -except ImportError: - pass +from contextlib import suppress +with suppress(ImportError): + import cv2Tools
Ruff
17-20: Use
contextlib.suppress(ImportError)
instead oftry
-except
-pass
Replace with
contextlib.suppress(ImportError)
(SIM105)
137-139
: Optimize instance processing.The loop processing the instances can be optimized by using list comprehensions.
-    for instance, skeleton in zip(labeled_frame.instances, skeletons_list):
-        skeleton_instance = instance_to_skeleton_instance(instance, skeleton)
-        skeleton_instances_list.append(skeleton_instance)
+    skeleton_instances_list.extend(
+        instance_to_skeleton_instance(instance, skeleton)
+        for instance, skeleton in zip(labeled_frame.instances, skeletons_list)
+    )
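The extend-with-generator pattern in the suggestion above is equivalent to the append loop it replaces. A toy version with stand-in data (the names here are illustrative, not the real sleap_io API):

```python
instances = ["inst_a", "inst_b"]
skeletons = ["skel_1", "skel_2"]

def to_skeleton_instance(instance, skeleton):
    # Stand-in for instance_to_skeleton_instance.
    return f"{instance}@{skeleton}"

# append-in-a-loop version
looped = []
for instance, skeleton in zip(instances, skeletons):
    looped.append(to_skeleton_instance(instance, skeleton))

# extend-with-generator version, as suggested
extended = []
extended.extend(
    to_skeleton_instance(instance, skeleton)
    for instance, skeleton in zip(instances, skeletons)
)

print(extended == looped)  # -> True
```

The generator form avoids a method lookup and call per item, and reads as a single bulk operation.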
183-187
: Optimize edge processing logic.The current implementation iterates over skeleton edges multiple times. Consider optimizing the logic to reduce complexity.
-    skeleton_edges = dict(enumerate(skeleton.nodes))
-    for i, source in skeleton_edges.items():
-        for destination in list(skeleton_edges.values())[i:]:
-            if Edge(source, destination) in skeleton.edges:
-                nwb_edges.append([i, list(skeleton_edges.values()).index(destination)])
+    node_indices = {node: idx for idx, node in enumerate(skeleton.nodes)}
+    nwb_edges = [
+        [node_indices[edge.source], node_indices[edge.destination]]
+        for edge in skeleton.edges
+    ]
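The rewrite above replaces repeated list scans with a single node-to-index map, so each edge lookup becomes a constant-time dict access. A runnable sketch with simplified stand-ins for the SLEAP node and Edge types (the data is hypothetical):

```python
from collections import namedtuple

Edge = namedtuple("Edge", ["source", "destination"])

nodes = ["head", "thorax", "abdomen"]
edges = [Edge("head", "thorax"), Edge("thorax", "abdomen")]

# Build the lookup table once...
node_indices = {node: idx for idx, node in enumerate(nodes)}

# ...then resolve each edge to an index pair with dict lookups.
nwb_edges = [
    [node_indices[e.source], node_indices[e.destination]] for e in edges
]

print(nwb_edges)  # -> [[0, 1], [1, 2]]
```

This turns a quadratic node-pair scan into a single linear pass over the edge list.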
711-712
: Remove unused variablesprocessing_module_name
andnwb_processing_module
.These variables are assigned but never used, which can lead to confusion and clutter.
-    processing_module_name = f"SLEAP_VIDEO_{i:03}_{video_path.stem}"
-    nwb_processing_module = get_processing_module_for_video(
-        processing_module_name, nwbfile
-    )
Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Files selected for processing (2)
- sleap_io/io/nwb.py (17 hunks)
- tests/io/test_main.py (1 hunks)
Additional context used
Ruff
tests/io/test_main.py
29-29: Use
is
andis not
for type comparisons, orisinstance()
for isinstance checks(E721)
30-30: Use
is
andis not
for type comparisons, orisinstance()
for isinstance checks(E721)
36-36: Use
is
andis not
for type comparisons, orisinstance()
for isinstance checks(E721)
sleap_io/io/nwb.py
17-20: Use
contextlib.suppress(ImportError)
instead oftry
-except
-pass
Replace with
contextlib.suppress(ImportError)
(SIM105)
Additional comments not posted (7)
tests/io/test_main.py (1)
25-30
: Useisinstance()
for type checks.Replace direct type comparisons with
isinstance()
for better readability and performance.- assert type(loaded_labels) == Labels - assert type(load_file(tmp_path / "test_nwb.nwb")) == Labels + assert isinstance(loaded_labels, Labels) + assert isinstance(load_file(tmp_path / "test_nwb.nwb"), Labels)Tools
Ruff
29-29: Use
is
andis not
for type comparisons, orisinstance()
for isinstance checks(E721)
30-30: Use
is
andis not
for type comparisons, orisinstance()
for isinstance checks(E721)
sleap_io/io/nwb.py (6)
Line range hint
778-809
: LGTM!The function correctly appends a SLEAP
Labels
object to an existing NWB data file.
197-220
: Add error handling for invalid inputs.The function lacks error handling for invalid
instance
andskeleton
inputs, which can improve robustness and maintainability.+ if not isinstance(instance, Instance): + raise ValueError("Invalid input: Expected an Instance object.") + if not isinstance(skeleton, NWBSkeleton): + raise ValueError("Invalid input: Expected an NWBSkeleton object.")Likely invalid or redundant comment.
915-915
: Add error handling for invalid inputs.The function lacks error handling for invalid inputs, which can improve robustness and maintainability.
+ if not isinstance(track_data_df, pd.DataFrame) or not isinstance(timestamps, (np.ndarray, list)): + raise ValueError("Invalid input: Expected a pandas DataFrame and an array-like object.")Likely invalid or redundant comment.
676-695
: Add error handling for invalid inputs.The function lacks error handling for invalid
labels
,nwbfile
,pose_estimation_metadata
,frame_inds
, andframe_path
inputs, which can improve robustness and maintainability.+ if not isinstance(labels, Labels): + raise ValueError("Invalid input: Expected a Labels object.") + if not isinstance(nwbfile, NWBFile): + raise ValueError("Invalid input: Expected an NWBFile object.") + if pose_estimation_metadata is not None and not isinstance(pose_estimation_metadata, dict): + raise ValueError("Invalid input: Expected a dictionary for pose_estimation_metadata.") + if frame_inds is not None and not isinstance(frame_inds, list): + raise ValueError("Invalid input: Expected a list for frame_inds.") + if frame_path is not None and not isinstance(frame_path, str): + raise ValueError("Invalid input: Expected a string for frame_path.")Likely invalid or redundant comment.
75-78
: Handle multiple external files.The current implementation only supports single-file external videos. Consider implementing logic to handle multiple external files.
- if source_video.format == "external" and len(source_video.external_file) == 1: - video = Video(source_video.external_file[0]) - else: - raise NotImplementedError("Only single-file external videos are supported.") + if source_video.format == "external": + if len(source_video.external_file) == 1: + video = Video(source_video.external_file[0]) + else: + # Implement logic for handling multiple external files + raise NotImplementedError("Multiple external files are not yet supported.")Likely invalid or redundant comment.
117-132
: Add error handling for invalid inputs.The function lacks error handling for invalid
labels
,skeletons_list
, andvideo_info
inputs, which can improve robustness and maintainability.+ if not isinstance(labels, Labels): + raise ValueError("Invalid input: Expected a Labels object.") + if not isinstance(skeletons_list, list) or not all(isinstance(s, NWBSkeleton) for s in skeletons_list): + raise ValueError("Invalid input: Expected a list of NWBSkeleton objects.") + if not isinstance(video_info, tuple) or len(video_info) != 3: + raise ValueError("Invalid input: Expected a tuple with three elements (dict, Video, ImageSeries).")Likely invalid or redundant comment.
sleap_io/io/nwb.py
Outdated
# TODO add a counter in the loop to track the number of instances
# instead of using id
Complete the TODO comment for instance tracking.
The function contains a TODO comment indicating incomplete functionality for tracking the number of instances.
Do you want me to assist in completing this functionality or open a GitHub issue to track this task?
def slp_skeleton_to_nwb(
    skeleton: SLEAPSkeleton, subject: Optional[Subject] = None
) -> NWBSkeleton:  # type: ignore[return]
    """Converts SLEAP skeleton to NWB skeleton.

    Args:
        skeleton: A SLEAP skeleton.
        subject: An NWB subject.

    Returns:
        An NWB skeleton.
    """
Add error handling for invalid inputs.

The function lacks error handling for invalid skeleton inputs, which can improve robustness and maintainability.
+ if not isinstance(skeleton, SLEAPSkeleton):
+ raise ValueError("Invalid input: Expected a SLEAPSkeleton object.")
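The fail-fast validation pattern recommended here can be exercised in isolation. The SLEAPSkeleton class below is a minimal stand-in, not the real sleap_io type (and a TypeError would arguably be the more conventional exception for a type check, but the sketch follows the review's suggestion):

```python
class SLEAPSkeleton:
    """Minimal stand-in for the real sleap_io skeleton type."""
    def __init__(self, nodes):
        self.nodes = nodes

def slp_skeleton_to_nwb_stub(skeleton):
    # Reject wrong types up front with a clear message instead of
    # failing later with an obscure AttributeError.
    if not isinstance(skeleton, SLEAPSkeleton):
        raise ValueError("Invalid input: Expected a SLEAPSkeleton object.")
    return [str(node) for node in skeleton.nodes]

print(slp_skeleton_to_nwb_stub(SLEAPSkeleton(["head", "tail"])))  # -> ['head', 'tail']

try:
    slp_skeleton_to_nwb_stub("not a skeleton")
except ValueError as err:
    caught = str(err)

print(caught)  # -> Invalid input: Expected a SLEAPSkeleton object.
```

The error message pins the failure to the call site, which is the main payoff when such conversion functions are chained together.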
def slp_skeleton_to_nwb(
    skeleton: SLEAPSkeleton, subject: Optional[Subject] = None
) -> NWBSkeleton:  # type: ignore[return]
    """Converts SLEAP skeleton to NWB skeleton.

    Args:
        skeleton: A SLEAP skeleton.
        subject: An NWB subject.

    Returns:
        An NWB skeleton.
    """
    if not isinstance(skeleton, SLEAPSkeleton):
        raise ValueError("Invalid input: Expected a SLEAPSkeleton object.")
Actionable comments posted: 0
Outside diff range, codebase verification and nitpick comments (2)
tests/io/test_main.py (1)
34-36
: Add assertions to verify test outcomes.The
test_nwb_training
function lacks assertions to verify that thesave_nwb
operation has the expected effect. Consider adding checks similar to those intest_nwb
to ensure the saved and loaded data match the expected structure and content.sleap_io/io/nwb.py (1)
Line range hint
787-818
: Add error handling for invalid inputs.Consider adding error handling for invalid inputs to improve robustness and maintainability.
+ if not isinstance(labels, Labels): + raise ValueError("Invalid input: Expected a Labels object.") + if not isinstance(filename, str): + raise ValueError("Invalid input: Expected a string for filename.") + if pose_estimation_metadata is not None and not isinstance(pose_estimation_metadata, dict): + raise ValueError("Invalid input: Expected a dictionary for pose_estimation_metadata.") + if frame_inds is not None and not isinstance(frame_inds, list): + raise ValueError("Invalid input: Expected a list for frame_inds.") + if frame_path is not None and not isinstance(frame_path, str): + raise ValueError("Invalid input: Expected a string for frame_path.")
Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Files selected for processing (2)
- sleap_io/io/nwb.py (17 hunks)
- tests/io/test_main.py (1 hunks)
Additional context used
Ruff
tests/io/test_main.py
29-29: Use
is
andis not
for type comparisons, orisinstance()
for isinstance checks(E721)
30-30: Use
is
andis not
for type comparisons, orisinstance()
for isinstance checks(E721)
sleap_io/io/nwb.py
17-20: Use
contextlib.suppress(ImportError)
instead oftry
-except
-pass
Replace with
contextlib.suppress(ImportError)
(SIM105)
Additional comments not posted (11)
tests/io/test_main.py (1)
25-32
: Useisinstance()
for type checks.Replace direct type comparisons with
isinstance()
for better readability and performance.- assert type(loaded_labels) == Labels - assert type(load_file(tmp_path / "test_nwb.nwb")) == Labels + assert isinstance(loaded_labels, Labels) + assert isinstance(load_file(tmp_path / "test_nwb.nwb"), Labels)Tools
Ruff
29-29: Use
is
andis not
for type comparisons, orisinstance()
for isinstance checks(E721)
30-30: Use
is
andis not
for type comparisons, orisinstance()
for isinstance checks(E721)
sleap_io/io/nwb.py (10)
61-96: Add error handling for invalid inputs. The function lacks error handling for invalid inputs; adding it would improve robustness and maintainability.
75-78: Handle multiple external files. The current implementation only supports single-file external videos. Consider implementing logic to handle multiple external files.
99-115: Add error handling for invalid inputs. The function lacks error handling for invalid `skeleton` inputs; adding it would improve robustness and maintainability.
117-166: Add error handling for invalid inputs. The function lacks error handling for invalid inputs; adding it would improve robustness and maintainability.
168-194: Add error handling for invalid inputs. The function lacks error handling for invalid `skeleton` inputs; adding it would improve robustness and maintainability.
197-220: Add error handling for invalid inputs. The function lacks error handling for invalid inputs; adding it would improve robustness and maintainability.
215-216: Complete the TODO comment for instance tracking. The function contains a TODO comment indicating incomplete functionality for tracking the number of instances.
223-246: Add error handling for invalid inputs. The function lacks error handling for invalid `videos` inputs; adding it would improve robustness and maintainability.
685-786: Add error handling for invalid inputs. The function lacks error handling for invalid `labels`, `nwbfile`, `pose_estimation_metadata`, `frame_inds`, and `frame_path` inputs; adding it would improve robustness and maintainability.
924-924: Add error handling for invalid inputs. The function lacks error handling for invalid inputs; adding it would improve robustness and maintainability.
Codecov Report
Attention: patch coverage is incomplete; per the diff below, 47 of the 180 added lines are not covered by tests.
Additional details and impacted files@@ Coverage Diff @@
## main #104 +/- ##
==========================================
- Coverage 96.01% 94.22% -1.80%
==========================================
Files 17 17
Lines 2035 2215 +180
==========================================
+ Hits 1954 2087 +133
- Misses 81 128 +47 ☔ View full report in Codecov by Sentry. |
Actionable comments posted: 5
Outside diff range, codebase verification and nitpick comments (1)
sleap_io/io/nwb.py (1)
17-20: Use `contextlib.suppress` for import suppression. Replace the `try`-`except`-`pass` block with `contextlib.suppress` to improve readability.
-try:
-    import cv2
-except ImportError:
-    pass
+from contextlib import suppress
+with suppress(ImportError):
+    import cv2
Ruff
17-20: Use `contextlib.suppress(ImportError)` instead of `try`-`except`-`pass`. Replace with `contextlib.suppress(ImportError)` (SIM105)
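A self-contained sketch of the suggested pattern; `has_opencv` is a hypothetical helper (not part of sleap_io), and the snippet runs whether or not OpenCV is installed:

```python
from contextlib import suppress

cv2 = None
with suppress(ImportError):
    import cv2  # optional dependency; `cv2` stays None when OpenCV is absent

def has_opencv() -> bool:
    """Report whether the optional OpenCV backend imported successfully."""
    return cv2 is not None
```

This reads the same as `try`/`except ImportError`/`pass` but states the intent (ignore exactly this error) in one line, which is what SIM105 is about.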
Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Files selected for processing (1)
- sleap_io/io/nwb.py (17 hunks)
Additional context used
Ruff
sleap_io/io/nwb.py
17-20: Use `contextlib.suppress(ImportError)` instead of `try`-`except`-`pass`. Replace with `contextlib.suppress(ImportError)` (SIM105)
Additional comments not posted (6)
sleap_io/io/nwb.py (6)
Line range hint 753-784: LGTM! The function is well-implemented and handles the `as_training` parameter appropriately.
Line range hint 818-887: LGTM! The function is well-implemented and handles metadata and timestamps appropriately.
197-220: Add error handling and complete the TODO comment for instance tracking. The function lacks error handling for invalid inputs, and the TODO comment indicates incomplete functionality for tracking the number of instances.
+ if not isinstance(instance, Instance):
+     raise ValueError("Invalid input: Expected an Instance object.")
+ if not isinstance(skeleton, NWBSkeleton):
+     raise ValueError("Invalid input: Expected an NWBSkeleton object.")
Do you want me to assist in completing the functionality for tracking the number of instances, or open a GitHub issue to track this task?
Likely invalid or redundant comment.
890-890: Add error handling for invalid inputs. The function lacks error handling for invalid inputs; adding it would improve robustness and maintainability.
+ if not isinstance(track_data_df, pd.DataFrame) or not isinstance(timestamps, (np.ndarray, list)):
+     raise ValueError("Invalid input: Expected a pandas DataFrame and an array-like object.")
Likely invalid or redundant comment.
168-194: Optimize edge processing logic. The current implementation iterates over skeleton edges multiple times. Consider optimizing the logic to reduce complexity.
- skeleton_edges = dict(enumerate(skeleton.nodes))
- for i, source in skeleton_edges.items():
-     for destination in list(skeleton_edges.values())[i:]:
-         if Edge(source, destination) in skeleton.edges:
-             nwb_edges.append([i, list(skeleton_edges.values()).index(destination)])
+ node_indices = {node: idx for idx, node in enumerate(skeleton.nodes)}
+ nwb_edges = [
+     [node_indices[edge.source], node_indices[edge.destination]]
+     for edge in skeleton.edges
+ ]
Likely invalid or redundant comment.
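The suggested rewrite trades the quadratic node-pair scan for one dictionary build plus O(1) lookups per edge. A minimal runnable sketch using stand-in `Node`/`Edge` dataclasses (not the real sleap_io classes):

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Node:
    name: str

@dataclass(frozen=True)
class Edge:
    source: Node
    destination: Node

head, thorax, abdomen = Node("head"), Node("thorax"), Node("abdomen")
nodes = [head, thorax, abdomen]
edges = [Edge(head, thorax), Edge(thorax, abdomen)]

# Build the node-to-index map once (O(n)); each edge lookup is then O(1),
# instead of scanning candidate node pairs as the original nested loops did.
node_indices = {node: idx for idx, node in enumerate(nodes)}
nwb_edges = [[node_indices[e.source], node_indices[e.destination]] for e in edges]
```

For a skeleton with n nodes and m edges this is O(n + m) rather than roughly O(n^2 * m) for the pairwise membership test.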
99-115: Add error handling for invalid skeleton inputs. The function lacks error handling for invalid `skeleton` inputs; adding it would improve robustness and maintainability.
+ if not isinstance(skeleton, NWBSkeleton):
+     raise ValueError("Invalid input: Expected an NWBSkeleton object.")
Likely invalid or redundant comment.
sleap_io/io/nwb.py
Outdated
def labels_to_pose_training(
    labels: Labels,
    skeletons_list: list[NWBSkeleton],  # type: ignore[return]
    video_info: tuple[dict[int, str], Video, ImageSeries],
) -> PoseTraining:  # type: ignore[return]
    """Creates an NWB PoseTraining object from a Labels object.

    Args:
        labels: A Labels object.
        skeletons_list: A list of NWB skeletons.
        video_info: A tuple containing a dictionary mapping frame indices to file paths,
            the video, and the `ImageSeries`.

    Returns:
        A PoseTraining object.
    """
    training_frame_list = []
    skeleton_instances_list = []
    source_video_list = []
    for i, labeled_frame in enumerate(labels.labeled_frames):
        for instance, skeleton in zip(labeled_frame.instances, skeletons_list):
            skeleton_instance = instance_to_skeleton_instance(instance, skeleton)
            skeleton_instances_list.append(skeleton_instance)

        training_frame_skeleton_instances = SkeletonInstances(
            skeleton_instances=skeleton_instances_list
        )
        training_frame_video_index = labeled_frame.frame_idx

        image_series = video_info[2]
        source_video = image_series
        if source_video not in source_video_list:
            source_video_list.append(source_video)
        training_frame = TrainingFrame(
            name=f"training_frame_{i}",
            annotator="N/A",
            skeleton_instances=training_frame_skeleton_instances,
            source_video=source_video,
            source_video_frame_index=training_frame_video_index,
        )
        training_frame_list.append(training_frame)

    training_frames = TrainingFrames(training_frames=training_frame_list)
    source_videos = SourceVideos(image_series=source_video_list)
    pose_training = PoseTraining(
        training_frames=training_frames,
        source_videos=source_videos,
    )
    return pose_training
Add error handling and handle multiple external files.
The function lacks error handling for invalid inputs. Consider implementing logic to handle multiple external files.
+ if not isinstance(labels, Labels):
+     raise ValueError("Invalid input: Expected a Labels object.")
+ if not isinstance(skeletons_list, list) or not all(isinstance(s, NWBSkeleton) for s in skeletons_list):
+     raise ValueError("Invalid input: Expected a list of NWBSkeleton objects.")
+ if not isinstance(video_info, tuple) or len(video_info) != 3:
+     raise ValueError("Invalid input: Expected a tuple with three elements (dict, Video, ImageSeries).")
- if source_video.format == "external" and len(source_video.external_file) == 1:
-     video = Video(source_video.external_file[0])
- else:
-     raise NotImplementedError("Only single-file external videos are supported.")
+ if source_video.format == "external":
+     if len(source_video.external_file) == 1:
+         video = Video(source_video.external_file[0])
+     else:
+         # Implement logic for handling multiple external files
+         raise NotImplementedError("Multiple external files are not yet supported.")
sleap_io/io/nwb.py
Outdated
def append_nwb_training(
    labels: Labels,
    nwbfile: NWBFile,
    pose_estimation_metadata: Optional[dict] = None,
    frame_inds: Optional[list[int]] = None,
    frame_path: Optional[str] = None,
) -> NWBFile:
    """Append training data from a Labels object to an in-memory NWB file.

    Args:
        labels: A general labels object.
        nwbfile: An in-memory NWB file.
        pose_estimation_metadata: Metadata for pose estimation.
        frame_inds: The indices of the frames to write. If None, all frames are written.
        frame_path: The path to save the frames. If None, the path is the video
            filename without the extension.

    Returns:
        An in-memory NWB file with the PoseTraining data appended.
    """
    pose_estimation_metadata = pose_estimation_metadata or dict()
    provenance = labels.provenance
    default_metadata = dict(scorer=str(provenance))
    sleap_version = provenance.get("sleap_version", None)
    default_metadata["source_software_version"] = sleap_version

    subject = Subject(subject_id="No specified id", species="No specified species")
    nwbfile.subject = subject

    for i, video in enumerate(labels.videos):
        video_path = (
            Path(video.filename)
            if isinstance(video.filename, str)
            else video.filename[i]
        )
        processing_module_name = f"SLEAP_VIDEO_{i:03}_{video_path.stem}"
        nwb_processing_module = get_processing_module_for_video(
            processing_module_name, nwbfile
        )
        default_metadata["original_videos"] = [f"{video.filename}"]
        default_metadata["labeled_videos"] = [f"{video.filename}"]
        default_metadata.update(pose_estimation_metadata)

        skeletons_list = [
            slp_skeleton_to_nwb(skeleton, subject) for skeleton in labels.skeletons
        ]
        skeletons = Skeletons(skeletons=skeletons_list)
        nwb_processing_module.add(skeletons)
        video_info = write_video_to_path(
            labels.videos[0], frame_inds, frame_path=frame_path
        )
        pose_training = labels_to_pose_training(labels, skeletons_list, video_info)
        nwb_processing_module.add(pose_training)

        confidence_definition = "Softmax output of the deep neural network"
        reference_frame = (
            "The coordinates are in (x, y) relative to the top-left of the image. "
            "Coordinates refer to the midpoint of the pixel. "
            "That is, t the midpoint of the top-left pixel is at (0, 0), whereas "
            "the top-left corner of that same pixel is at (-0.5, -0.5)."
        )
        pose_estimation_series_list = []
        for node in skeletons_list[0].nodes:
            pose_estimation_series = PoseEstimationSeries(
                name=node,
                description=f"Marker placed on {node}",
                data=np.random.rand(100, 2),
                unit="pixels",
                reference_frame=reference_frame,
                timestamps=np.linspace(0, 10, num=100),
                confidence=np.random.rand(100),
                confidence_definition=confidence_definition,
            )
            pose_estimation_series_list.append(pose_estimation_series)

        camera = nwbfile.create_device(
            name=f"camera {i}",
            description=f"Camera used to record video {i}",
            manufacturer="No specified manufacturer",
        )
        try:
            dimensions = np.array([[video.backend.shape[1], video.backend.shape[2]]])
        except AttributeError:
            dimensions = np.array([[400, 400]])

        pose_estimation = PoseEstimation(
            name="pose_estimation",
            pose_estimation_series=pose_estimation_series_list,
            description="Estimated positions of the nodes in the video",
            original_videos=[video.filename for video in labels.videos],
            labeled_videos=[video.filename for video in labels.videos],
            dimensions=dimensions,
            devices=[camera],
            scorer="No specified scorer",
            source_software="SLEAP",
            source_software_version=sleap_version,
            skeleton=skeletons_list[0],
        )
        nwb_processing_module.add(pose_estimation)

    return nwbfile
Add error handling for invalid inputs.
The function lacks error handling for invalid `labels`, `nwbfile`, `pose_estimation_metadata`, `frame_inds`, and `frame_path` inputs; adding it would improve robustness and maintainability.
+ if not isinstance(labels, Labels):
+     raise ValueError("Invalid input: Expected a Labels object.")
+ if not isinstance(nwbfile, NWBFile):
+     raise ValueError("Invalid input: Expected an NWBFile object.")
+ if pose_estimation_metadata is not None and not isinstance(pose_estimation_metadata, dict):
+     raise ValueError("Invalid input: Expected a dictionary for pose_estimation_metadata.")
+ if frame_inds is not None and not isinstance(frame_inds, list):
+     raise ValueError("Invalid input: Expected a list for frame_inds.")
+ if frame_path is not None and not isinstance(frame_path, str):
+     raise ValueError("Invalid input: Expected a string for frame_path.")
Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
def append_nwb_training(
    labels: Labels,
    nwbfile: NWBFile,
    pose_estimation_metadata: Optional[dict] = None,
    frame_inds: Optional[list[int]] = None,
    frame_path: Optional[str] = None,
) -> NWBFile:
    """Append training data from a Labels object to an in-memory NWB file.

    Args:
        labels: A general labels object.
        nwbfile: An in-memory NWB file.
        pose_estimation_metadata: Metadata for pose estimation.
        frame_inds: The indices of the frames to write. If None, all frames are written.
        frame_path: The path to save the frames. If None, the path is the video
            filename without the extension.

    Returns:
        An in-memory NWB file with the PoseTraining data appended.
    """
+   if not isinstance(labels, Labels):
+       raise ValueError("Invalid input: Expected a Labels object.")
+   if not isinstance(nwbfile, NWBFile):
+       raise ValueError("Invalid input: Expected an NWBFile object.")
+   if pose_estimation_metadata is not None and not isinstance(pose_estimation_metadata, dict):
+       raise ValueError("Invalid input: Expected a dictionary for pose_estimation_metadata.")
+   if frame_inds is not None and not isinstance(frame_inds, list):
+       raise ValueError("Invalid input: Expected a list for frame_inds.")
+   if frame_path is not None and not isinstance(frame_path, str):
+       raise ValueError("Invalid input: Expected a string for frame_path.")
    pose_estimation_metadata = pose_estimation_metadata or dict()
    provenance = labels.provenance
    default_metadata = dict(scorer=str(provenance))
    sleap_version = provenance.get("sleap_version", None)
    default_metadata["source_software_version"] = sleap_version

    subject = Subject(subject_id="No specified id", species="No specified species")
    nwbfile.subject = subject

    for i, video in enumerate(labels.videos):
        video_path = (
            Path(video.filename)
            if isinstance(video.filename, str)
            else video.filename[i]
        )
        processing_module_name = f"SLEAP_VIDEO_{i:03}_{video_path.stem}"
        nwb_processing_module = get_processing_module_for_video(
            processing_module_name, nwbfile
        )
        default_metadata["original_videos"] = [f"{video.filename}"]
        default_metadata["labeled_videos"] = [f"{video.filename}"]
        default_metadata.update(pose_estimation_metadata)

        skeletons_list = [
            slp_skeleton_to_nwb(skeleton, subject) for skeleton in labels.skeletons
        ]
        skeletons = Skeletons(skeletons=skeletons_list)
        nwb_processing_module.add(skeletons)
        video_info = write_video_to_path(
            labels.videos[0], frame_inds, frame_path=frame_path
        )
        pose_training = labels_to_pose_training(labels, skeletons_list, video_info)
        nwb_processing_module.add(pose_training)

        confidence_definition = "Softmax output of the deep neural network"
        reference_frame = (
            "The coordinates are in (x, y) relative to the top-left of the image. "
            "Coordinates refer to the midpoint of the pixel. "
            "That is, t the midpoint of the top-left pixel is at (0, 0), whereas "
            "the top-left corner of that same pixel is at (-0.5, -0.5)."
        )
        pose_estimation_series_list = []
        for node in skeletons_list[0].nodes:
            pose_estimation_series = PoseEstimationSeries(
                name=node,
                description=f"Marker placed on {node}",
                data=np.random.rand(100, 2),
                unit="pixels",
                reference_frame=reference_frame,
                timestamps=np.linspace(0, 10, num=100),
                confidence=np.random.rand(100),
                confidence_definition=confidence_definition,
            )
            pose_estimation_series_list.append(pose_estimation_series)

        camera = nwbfile.create_device(
            name=f"camera {i}",
            description=f"Camera used to record video {i}",
            manufacturer="No specified manufacturer",
        )
        try:
            dimensions = np.array([[video.backend.shape[1], video.backend.shape[2]]])
        except AttributeError:
            dimensions = np.array([[400, 400]])

        pose_estimation = PoseEstimation(
            name="pose_estimation",
            pose_estimation_series=pose_estimation_series_list,
            description="Estimated positions of the nodes in the video",
            original_videos=[video.filename for video in labels.videos],
            labeled_videos=[video.filename for video in labels.videos],
            dimensions=dimensions,
            devices=[camera],
            scorer="No specified scorer",
            source_software="SLEAP",
            source_software_version=sleap_version,
            skeleton=skeletons_list[0],
        )
        nwb_processing_module.add(pose_estimation)

    return nwbfile
Actionable comments posted: 0
Outside diff range, codebase verification and nitpick comments (10)
sleap_io/io/nwb.py (10)
61-96: Enhance error handling for video formats. The function currently raises a `NotImplementedError` for unsupported video formats. Consider adding more descriptive error handling or support for multiple external files.
- raise NotImplementedError("Only single-file external videos are supported.")
+ raise ValueError("Unsupported video format: Only single-file external videos are supported.")
99-115: Add error handling for invalid skeleton inputs. The function lacks error handling for invalid `skeleton` inputs; adding it would improve robustness and maintainability.
+ if not isinstance(skeleton, NWBSkeleton):
+     raise ValueError("Invalid input: Expected an NWBSkeleton object.")
117-166: Add error handling for invalid inputs. The function lacks error handling for invalid inputs; adding it would improve robustness and maintainability.
+ if not isinstance(labels, Labels):
+     raise ValueError("Invalid input: Expected a Labels object.")
+ if not isinstance(skeletons_list, list) or not all(isinstance(s, NWBSkeleton) for s in skeletons_list):
+     raise ValueError("Invalid input: Expected a list of NWBSkeleton objects.")
+ if not isinstance(video_info, tuple) or len(video_info) != 3:
+     raise ValueError("Invalid input: Expected a tuple with three elements (dict, Video, ImageSeries).")
Optimize `source_video` handling. The current implementation assumes `source_video` is unique for each frame. Consider optimizing the logic to handle cases where `source_video` might repeat.
- if source_video not in source_video_list:
-     source_video_list.append(source_video)
+ if source_video not in source_video_list:
+     source_video_list.append(source_video)
+ else:
+     # Handle repeated source_video cases
168-194: Add error handling for invalid skeleton inputs. The function lacks error handling for invalid `skeleton` inputs; adding it would improve robustness and maintainability.
+ if not isinstance(skeleton, SLEAPSkeleton):
+     raise ValueError("Invalid input: Expected a SLEAPSkeleton object.")
Optimize edge processing logic. The current implementation iterates over skeleton edges multiple times. Consider optimizing the logic to reduce complexity.
- skeleton_edges = dict(enumerate(skeleton.nodes))
- for i, source in skeleton_edges.items():
-     for destination in list(skeleton_edges.values())[i:]:
-         if Edge(source, destination) in skeleton.edges:
-             nwb_edges.append([i, list(skeleton_edges.values()).index(destination)])
+ node_indices = {node: idx for idx, node in enumerate(skeleton.nodes)}
+ nwb_edges = [
+     [node_indices[edge.source], node_indices[edge.destination]]
+     for edge in skeleton.edges
+ ]
197-220: Add error handling for invalid inputs. The function lacks error handling for invalid inputs; adding it would improve robustness and maintainability.
+ if not isinstance(instance, Instance):
+     raise ValueError("Invalid input: Expected an Instance object.")
+ if not isinstance(skeleton, NWBSkeleton):
+     raise ValueError("Invalid input: Expected an NWBSkeleton object.")
Improve instance naming logic. Using `id(instance)` for naming can lead to non-intuitive names. Consider using a counter or a more descriptive naming scheme.
- name=f"skeleton_instance_{id(instance)}",
+ name=f"skeleton_instance_{instance.index}",  # Assuming `index` is a unique identifier
223-246: Add error handling for invalid inputs. The function lacks error handling for invalid `videos` inputs; adding it would improve robustness and maintainability.
+ if not isinstance(videos, list) or not all(isinstance(video, Video) for video in videos):
+     raise ValueError("Invalid input: Expected a list of Video objects.")
Use dynamic rate for video. The function uses a fixed rate of 30.0. Consider using `video.backend.fps` when available.
- rate=30.0,  # TODO - change to `video.backend.fps` when available
+ rate=video.backend.fps if hasattr(video.backend, 'fps') else 30.0,
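The same fallback can also be written with `getattr`'s default argument, which does the lookup once; the backend classes below are hypothetical stand-ins for illustration:

```python
class VideoBackend:
    """Hypothetical backend that does not expose an fps attribute."""

class VideoBackendWithFps(VideoBackend):
    fps = 25.0

def video_rate(backend, default: float = 30.0) -> float:
    # getattr with a default covers both cases in one lookup,
    # equivalent to the hasattr-based conditional in the suggestion.
    return getattr(backend, "fps", default)
```

Either form keeps the 30.0 default as a last resort while preferring the backend's real frame rate when it is available.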
248-317: Add error handling for invalid inputs. The function lacks error handling for invalid inputs; adding it would improve robustness and maintainability.
+ if not isinstance(video, Video):
+     raise ValueError("Invalid input: Expected a Video object.")
+ if frame_inds is not None and not isinstance(frame_inds, list):
+     raise ValueError("Invalid input: Expected a list of frame indices.")
+ if not isinstance(image_format, str):
+     raise ValueError("Invalid input: Expected a string for image_format.")
Use dynamic save path. The function uses a fixed save path based on the video filename. Consider allowing more flexible path configurations.
- save_path = video.filename.split(".")[0]
+ save_path = frame_path or video.filename.split(".")[0]
637-738: Add error handling for invalid inputs. The function lacks error handling for invalid inputs; adding it would improve robustness and maintainability.
+ if not isinstance(labels, Labels):
+     raise ValueError("Invalid input: Expected a Labels object.")
+ if not isinstance(nwbfile, NWBFile):
+     raise ValueError("Invalid input: Expected an NWBFile object.")
+ if pose_estimation_metadata is not None and not isinstance(pose_estimation_metadata, dict):
+     raise ValueError("Invalid input: Expected a dictionary for pose_estimation_metadata.")
+ if frame_inds is not None and not isinstance(frame_inds, list):
+     raise ValueError("Invalid input: Expected a list for frame_inds.")
+ if frame_path is not None and not isinstance(frame_path, str):
+     raise ValueError("Invalid input: Expected a string for frame_path.")
Use dynamic metadata. The function uses fixed metadata for the subject. Consider allowing more flexible metadata configurations.
- subject = Subject(subject_id="No specified id", species="No specified species")
+ subject = Subject(subject_id=labels.provenance.get("subject_id", "No specified id"),
+                   species=labels.provenance.get("species", "No specified species"))
Line range hint 739-770: Add error handling for invalid inputs. The function lacks error handling for invalid inputs; adding it would improve robustness and maintainability.
+ if not isinstance(labels, Labels):
+     raise ValueError("Invalid input: Expected a Labels object.")
+ if not isinstance(filename, str):
+     raise ValueError("Invalid input: Expected a string for filename.")
+ if pose_estimation_metadata is not None and not isinstance(pose_estimation_metadata, dict):
+     raise ValueError("Invalid input: Expected a dictionary for pose_estimation_metadata.")
+ if frame_inds is not None and not isinstance(frame_inds, list):
+     raise ValueError("Invalid input: Expected a list for frame_inds.")
+ if frame_path is not None and not isinstance(frame_path, str):
+     raise ValueError("Invalid input: Expected a string for frame_path.")
+ if as_training is not None and not isinstance(as_training, bool):
+     raise ValueError("Invalid input: Expected a boolean for as_training.")
Line range hint 876-914: Add error handling for invalid inputs. The function lacks error handling for invalid inputs; adding it would improve robustness and maintainability.
+ if not isinstance(track_data_df, pd.DataFrame):
+     raise ValueError("Invalid input: Expected a pandas DataFrame.")
+ if not isinstance(timestamps, (np.ndarray, list)):
+     raise ValueError("Invalid input: Expected an array-like object for timestamps.")
Handle non-uniform sampling. The function assumes uniform sampling for rate calculation. Consider handling non-uniform sampling more explicitly.
- uniform_samples = np.unique(sample_periods.round(5)).size == 1
+ uniform_samples = np.allclose(sample_periods, sample_periods[0], atol=0.0001)
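A small sketch of why the tolerance-based check is more robust: timestamp jitter well below one sample period should still count as uniform sampling, and an explicit `atol` expresses that directly, whereas rounding-and-unique can split the jittered periods into several "distinct" values. The values below are illustrative, not from sleap_io:

```python
import numpy as np

rng = np.random.default_rng(0)
# Timestamps at ~10 Hz with jitter far below one tick:
timestamps = np.linspace(0.0, 1.0, num=11) + rng.uniform(-5e-5, 5e-5, size=11)
sample_periods = np.diff(timestamps)

# The original check, np.unique(sample_periods.round(5)).size == 1, can report
# multiple "distinct" periods for this harmless jitter; an explicit tolerance
# treats the sampling as uniform.
uniform = np.allclose(sample_periods, sample_periods[0], atol=1e-3)
rate = 1.0 / sample_periods.mean() if uniform else None
```

Choosing `atol` as a small fraction of the nominal sample period makes the intent (how much jitter still counts as uniform) explicit and tunable.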
Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Files selected for processing (1)
- sleap_io/io/nwb.py (17 hunks)
Additional context used
Ruff
sleap_io/io/nwb.py
17-20: Use `contextlib.suppress(ImportError)` instead of `try`-`except`-`pass`. Replace with `contextlib.suppress(ImportError)` (SIM105)
27-27: `hdmf.utils.LabelledDict` imported but unused. Remove unused import: `hdmf.utils.LabelledDict` (F401)
Actionable comments posted: 0
Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Files selected for processing (2)
- pyproject.toml (1 hunks)
- tests/io/test_nwb.py (2 hunks)
Files skipped from review as they are similar to previous changes (2)
- pyproject.toml
- tests/io/test_nwb.py
Actionable comments posted: 7
Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Files selected for processing (2)
- environment.yml (1 hunks)
- sleap_io/io/nwb.py (17 hunks)
Files skipped from review due to trivial changes (1)
- environment.yml
Additional context used
Ruff
sleap_io/io/nwb.py
17-20: Use `contextlib.suppress(ImportError)` instead of `try`-`except`-`pass`. Replace with `contextlib.suppress(ImportError)` (SIM105)
sleap_io/io/nwb.py
Outdated
def nwb_skeleton_to_sleap(skeleton: NWBSkeleton) -> SLEAPSkeleton:  # type: ignore[return]
    """Converts an NWB skeleton to a SLEAP skeleton.

    Args:
        skeleton: An NWB skeleton.

    Returns:
        A SLEAP skeleton.
    """
    nodes = [Node(name=node) for node in skeleton.nodes]
    edges = [Edge(source=edge[0], destination=edge[1]) for edge in skeleton.edges]
    return SLEAPSkeleton(
        nodes=nodes,
        edges=edges,
        name=skeleton.name,
    )
Add type annotations and error handling in `nwb_skeleton_to_sleap`.
The function `nwb_skeleton_to_sleap` converts an NWB skeleton to a SLEAP skeleton. The function is straightforward but lacks type annotations and error handling.
Enhance the function with type annotations and error handling:
def nwb_skeleton_to_sleap(skeleton: NWBSkeleton) -> SLEAPSkeleton:
    if not isinstance(skeleton, NWBSkeleton):
        raise ValueError("Expected an NWBSkeleton object")
    nodes = [Node(name=node) for node in skeleton.nodes]
    edges = [Edge(source=edge[0], destination=edge[1]) for edge in skeleton.edges]
    return SLEAPSkeleton(nodes=nodes, edges=edges, name=skeleton.name)
Actionable comments posted: 7
Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Files selected for processing (1)
- sleap_io/io/nwb.py (17 hunks)
Additional context used
Ruff
sleap_io/io/nwb.py
17-20: Use `contextlib.suppress(ImportError)` instead of `try`-`except`-`pass`. Replace with `contextlib.suppress(ImportError)`. (SIM105)
Additional comments not posted (1)
sleap_io/io/nwb.py (1)
17-20: Refactor import handling using `contextlib.suppress`. To improve the readability and conciseness of the code, consider using `contextlib.suppress` for handling the optional import of `cv2`. This approach is cleaner and reduces the amount of code:
```diff
-try:
-    import cv2
-except ImportError:
-    pass
+from contextlib import suppress
+
+with suppress(ImportError):
+    import cv2
```
Likely invalid or redundant comment.
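As a self-contained illustration of the SIM105 pattern the bot keeps flagging (the module name below is a hypothetical placeholder, not an actual dependency of this project):

```python
import sys
from contextlib import suppress

# try/except ImportError/pass collapses to one context manager (Ruff SIM105).
with suppress(ImportError):
    import sleap_io_hypothetical_optional_dep  # hypothetical optional import

# Downstream code can branch on whether the import succeeded:
HAS_OPTIONAL = "sleap_io_hypothetical_optional_dep" in sys.modules
```

The behavior is identical to the `try`/`except ImportError`/`pass` form; only the intent reads more clearly.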
Tools
Ruff
17-20: Use `contextlib.suppress(ImportError)` instead of `try`-`except`-`pass`. Replace with `contextlib.suppress(ImportError)`. (SIM105)
sleap_io/io/nwb.py
Outdated
```python
def instance_to_skeleton_instance(
    instance: Instance, skeleton: NWBSkeleton  # type: ignore[return]
) -> SkeletonInstance:  # type: ignore[return]
    """Converts a SLEAP Instance to an NWB SkeletonInstance.

    Args:
        instance: A SLEAP Instance.
        skeleton: An NWB Skeleton.

    Returns:
        An NWB SkeletonInstance.
    """
    points_list = list(instance.points.values())
    node_locs = [[point.x, point.y] for point in points_list]
    np_node_locations = np.array(node_locs)
    return SkeletonInstance(
        name=f"skeleton_instance_{id(instance)}",
        id=np.uint64(id(instance)),
        node_locations=np_node_locations,
        node_visibility=[point.visible for point in instance.points.values()],
        skeleton=skeleton,
    )
```
Add error handling for input validation in `instance_to_skeleton_instance`.
This function should include checks to ensure that the inputs `instance` and `skeleton` are of the correct types and have the necessary attributes:
```diff
 def instance_to_skeleton_instance(
     instance: Instance, skeleton: NWBSkeleton  # type: ignore[return]
 ) -> SkeletonInstance:
+    if not isinstance(instance, Instance):
+        raise TypeError("Expected an Instance object")
+    if not isinstance(skeleton, NWBSkeleton):
+        raise TypeError("Expected an NWBSkeleton object")
     points_list = list(instance.points.values())
     ...
```
Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
def instance_to_skeleton_instance(
    instance: Instance, skeleton: NWBSkeleton  # type: ignore[return]
) -> SkeletonInstance:  # type: ignore[return]
    """Converts a SLEAP Instance to an NWB SkeletonInstance.

    Args:
        instance: A SLEAP Instance.
        skeleton: An NWB Skeleton.

    Returns:
        An NWB SkeletonInstance.
    """
    if not isinstance(instance, Instance):
        raise TypeError("Expected an Instance object")
    if not isinstance(skeleton, NWBSkeleton):
        raise TypeError("Expected an NWBSkeleton object")
    points_list = list(instance.points.values())
    node_locs = [[point.x, point.y] for point in points_list]
    np_node_locations = np.array(node_locs)
    return SkeletonInstance(
        name=f"skeleton_instance_{id(instance)}",
        id=np.uint64(id(instance)),
        node_locations=np_node_locations,
        node_visibility=[point.visible for point in instance.points.values()],
        skeleton=skeleton,
    )
```
sleap_io/io/nwb.py
Outdated
```python
def nwb_skeleton_to_sleap(skeleton: NWBSkeleton) -> SLEAPSkeleton:  # type: ignore[return]
    """Converts an NWB skeleton to a SLEAP skeleton.

    Args:
        skeleton: An NWB skeleton.

    Returns:
        A SLEAP skeleton.
    """
    nodes = [Node(name=node) for node in skeleton.nodes]
    edges = [Edge(source=edge[0], destination=edge[1]) for edge in skeleton.edges]
    return SLEAPSkeleton(
        nodes=nodes,
        edges=edges,
        name=skeleton.name,
    )
```
Add error handling for input validation in `nwb_skeleton_to_sleap`.
This function should include checks to ensure that the input `skeleton` is of the correct type and has the necessary attributes:
```diff
 def nwb_skeleton_to_sleap(skeleton: NWBSkeleton) -> SLEAPSkeleton:
+    if not isinstance(skeleton, NWBSkeleton):
+        raise TypeError("Expected an NWBSkeleton object")
+    if not hasattr(skeleton, 'nodes') or not hasattr(skeleton, 'edges'):
+        raise ValueError("Invalid NWBSkeleton object: Missing 'nodes' or 'edges'")
     nodes = [Node(name=node) for node in skeleton.nodes]
     edges = [Edge(source=edge[0], destination=edge[1]) for edge in skeleton.edges]
     ...
```
```python
def nwb_skeleton_to_sleap(skeleton: NWBSkeleton) -> SLEAPSkeleton:  # type: ignore[return]
    """Converts an NWB skeleton to a SLEAP skeleton.

    Args:
        skeleton: An NWB skeleton.

    Returns:
        A SLEAP skeleton.
    """
    if not isinstance(skeleton, NWBSkeleton):
        raise TypeError("Expected an NWBSkeleton object")
    if not hasattr(skeleton, 'nodes') or not hasattr(skeleton, 'edges'):
        raise ValueError("Invalid NWBSkeleton object: Missing 'nodes' or 'edges'")
    nodes = [Node(name=node) for node in skeleton.nodes]
    edges = [Edge(source=edge[0], destination=edge[1]) for edge in skeleton.edges]
    return SLEAPSkeleton(
        nodes=nodes,
        edges=edges,
        name=skeleton.name,
    )
```
sleap_io/io/nwb.py
Outdated
```python
def write_video_to_path(
    video: Video,
    frame_inds: Optional[list[int]] = None,
    image_format: str = "png",
    frame_path: Optional[str] = None,
) -> tuple[dict[int, str], Video, ImageSeries]:
    """Write individual frames of a video to a path.

    Args:
        video: The video to write.
        frame_inds: The indices of the frames to write. If None, all frames are written.
        image_format: The format of the image to write. Default is .png
        frame_path: The directory to save the frames to. If None, the path is the video
            filename without the extension.

    Returns:
        A tuple containing a dictionary mapping frame indices to file paths,
        the video, and the `ImageSeries`.
    """
    index_data = {}
    if frame_inds is None:
        frame_inds = list(range(video.backend.num_frames))

    if isinstance(video.filename, list):
        save_path = video.filename[0].split(".")[0]
    else:
        save_path = video.filename.split(".")[0]

    if frame_path is not None:
        save_path = frame_path

    try:
        os.makedirs(save_path, exist_ok=True)
    except PermissionError:
        filename_with_extension = video.filename.split("/")[-1]
        filename = filename_with_extension.split(".")[0]
        save_path = input("Permission denied. Enter a new path:") + "/" + filename
        os.makedirs(save_path, exist_ok=True)

    if "cv2" in sys.modules:
        for frame_idx in frame_inds:
            try:
                frame = video[frame_idx]
            except FileNotFoundError:
                video_filename = input("Video not found. Enter the video filename:")
                video = Video.from_filename(video_filename)
                frame = video[frame_idx]
            frame_path = f"{save_path}/frame_{frame_idx}.{image_format}"
            index_data[frame_idx] = frame_path
            cv2.imwrite(frame_path, frame)
    else:
        for frame_idx in frame_inds:
            try:
                frame = video[frame_idx]
            except FileNotFoundError:
                video_filename = input("Video not found. Enter the filename:")
                video = Video.from_filename(video_filename)
                frame = video[frame_idx]
            frame_path = f"{save_path}/frame_{frame_idx}.{image_format}"
            index_data[frame_idx] = frame_path
            iio.imwrite(frame_path, frame)

    image_series = ImageSeries(
        name="video",
        external_file=os.listdir(save_path),
        starting_frame=[0 for _ in range(len(os.listdir(save_path)))],
        rate=30.0,  # TODO - change to `video.backend.fps` when available
    )
    return index_data, video, image_series
```
Refactor and enhance error handling in `write_video_to_path`.
This function handles file operations and user input, which are error-prone areas. Consider refactoring to separate concerns and improve error recovery:
```diff
 def write_video_to_path(
     video: Video,
     frame_inds: Optional[list[int]] = None,
     image_format: str = "png",
     frame_path: Optional[str] = None,
 ) -> tuple[dict[int, str], Video, ImageSeries]:
     index_data = {}
     if frame_inds is None:
         frame_inds = list(range(video.backend.num_frames))
     ...
     try:
         os.makedirs(save_path, exist_ok=True)
     except PermissionError:
+        # Consider logging the error or handling it more gracefully
         filename_with_extension = video.filename.split("/")[-1]
         filename = filename_with_extension.split(".")[0]
         save_path = input("Permission denied. Enter a new path:") + "/" + filename
         os.makedirs(save_path, exist_ok=True)
     ...
```
Consider adding more robust error handling and possibly separating the user input handling into a different function or method to make the code cleaner and more maintainable.
Committable suggestion was skipped due to low confidence.
sleap_io/io/nwb.py
Outdated
```python
def pose_training_to_labels(pose_training: PoseTraining) -> Labels:  # type: ignore[return]
    """Creates a Labels object from an NWB PoseTraining object.

    Args:
        pose_training: An NWB PoseTraining object.

    Returns:
        A Labels object.
    """
    labeled_frames = []
    skeletons = {}
    training_frames = pose_training.training_frames.training_frames.values()
    for training_frame in training_frames:
        source_video = training_frame.source_video
        video = Video(source_video.external_file)

        frame_idx = training_frame.source_video_frame_index
        instances = []
        for instance in training_frame.skeleton_instances.skeleton_instances.values():
            if instance.skeleton.name not in skeletons:
                skeletons[instance.skeleton.name] = nwb_skeleton_to_sleap(
                    instance.skeleton
                )
            skeleton = skeletons[instance.skeleton.name]
            instances.append(
                Instance.from_numpy(
                    points=instance.node_locations[:], skeleton=skeleton
                )
            )
        labeled_frames.append(
            LabeledFrame(video=video, frame_idx=frame_idx, instances=instances)
        )
    return Labels(labeled_frames=labeled_frames)
```
Add error handling for input validation in `pose_training_to_labels`.
The function lacks checks for the validity of the `pose_training` input. Adding error handling for the input type and ensuring it contains the necessary attributes would enhance the robustness of the function:
```diff
 def pose_training_to_labels(pose_training: PoseTraining) -> Labels:
+    if not isinstance(pose_training, PoseTraining):
+        raise TypeError("Expected a PoseTraining object")
+    if not hasattr(pose_training, 'training_frames') or not isinstance(pose_training.training_frames, TrainingFrames):
+        raise ValueError("Invalid PoseTraining object: Missing 'training_frames'")
     labeled_frames = []
     skeletons = {}
     training_frames = pose_training.training_frames.training_frames.values()
     ...
```
```python
def pose_training_to_labels(pose_training: PoseTraining) -> Labels:  # type: ignore[return]
    """Creates a Labels object from an NWB PoseTraining object.

    Args:
        pose_training: An NWB PoseTraining object.

    Returns:
        A Labels object.
    """
    if not isinstance(pose_training, PoseTraining):
        raise TypeError("Expected a PoseTraining object")
    if not hasattr(pose_training, 'training_frames') or not isinstance(pose_training.training_frames, TrainingFrames):
        raise ValueError("Invalid PoseTraining object: Missing 'training_frames'")
    labeled_frames = []
    skeletons = {}
    training_frames = pose_training.training_frames.training_frames.values()
    for training_frame in training_frames:
        source_video = training_frame.source_video
        video = Video(source_video.external_file)
        frame_idx = training_frame.source_video_frame_index
        instances = []
        for instance in training_frame.skeleton_instances.skeleton_instances.values():
            if instance.skeleton.name not in skeletons:
                skeletons[instance.skeleton.name] = nwb_skeleton_to_sleap(
                    instance.skeleton
                )
            skeleton = skeletons[instance.skeleton.name]
            instances.append(
                Instance.from_numpy(
                    points=instance.node_locations[:], skeleton=skeleton
                )
            )
        labeled_frames.append(
            LabeledFrame(video=video, frame_idx=frame_idx, instances=instances)
        )
    return Labels(labeled_frames=labeled_frames)
```
sleap_io/io/nwb.py
Outdated
```python
def labels_to_pose_training(
    labels: Labels,
    skeletons_list: list[NWBSkeleton],  # type: ignore[return]
    video_info: tuple[dict[int, str], Video, ImageSeries],
) -> PoseTraining:  # type: ignore[return]
    """Creates an NWB PoseTraining object from a Labels object.

    Args:
        labels: A Labels object.
        skeletons_list: A list of NWB skeletons.
        video_info: A tuple containing a dictionary mapping frame indices to file paths,
            the video, and the `ImageSeries`.

    Returns:
        A PoseTraining object.
    """
    training_frame_list = []
    skeleton_instances_list = []
    source_video_list = []
    for i, labeled_frame in enumerate(labels.labeled_frames):
        for instance, skeleton in zip(labeled_frame.instances, skeletons_list):
            skeleton_instance = instance_to_skeleton_instance(instance, skeleton)
            skeleton_instances_list.append(skeleton_instance)

        training_frame_skeleton_instances = SkeletonInstances(
            skeleton_instances=skeleton_instances_list
        )
        training_frame_video_index = labeled_frame.frame_idx

        image_series = video_info[2]
        source_video = image_series
        if source_video not in source_video_list:
            source_video_list.append(source_video)
        training_frame = TrainingFrame(
            name=f"training_frame_{i}",
            annotator="N/A",
            skeleton_instances=training_frame_skeleton_instances,
            source_video=source_video,
            source_video_frame_index=training_frame_video_index,
        )
        training_frame_list.append(training_frame)

    training_frames = TrainingFrames(training_frames=training_frame_list)
    source_videos = SourceVideos(image_series=source_video_list)
    pose_training = PoseTraining(
        training_frames=training_frames,
        source_videos=source_videos,
    )
    return pose_training
```
Improve error handling and refactor `labels_to_pose_training` for clarity.
This function handles multiple complex operations and would benefit from clearer error handling and possibly some refactoring to improve readability:
```diff
 def labels_to_pose_training(
     labels: Labels,
     skeletons_list: list[NWBSkeleton],  # type: ignore[return]
     video_info: tuple[dict[int, str], Video, ImageSeries],
 ) -> PoseTraining:
+    if not isinstance(labels, Labels):
+        raise TypeError("Expected a Labels object")
+    if not isinstance(skeletons_list, list) or not all(isinstance(s, NWBSkeleton) for s in skeletons_list):
+        raise ValueError("Expected a list of NWBSkeleton objects")
+    if not isinstance(video_info, tuple) or len(video_info) != 3:
+        raise ValueError("Expected a tuple containing (dict, Video, ImageSeries)")
     ...
```
Consider refactoring to separate some of the nested logic into helper functions to improve the function's readability and maintainability.
Committable suggestion was skipped due to low confidence.
Actionable comments posted: 1
Outside diff range, codebase verification and nitpick comments (2)
sleap_io/io/nwb.py (2)
Line range hint `698-729`: Consider simplifying the function. The function `append_nwb` handles multiple parameters and data structures, which adds complexity. Consider refactoring to simplify the function or breaking it down into smaller, more manageable parts.
Line range hint `764-825`: Consider simplifying the function. The function `build_pose_estimation_container_for_track` handles multiple parameters and data structures, which adds complexity. Consider refactoring to simplify the function or breaking it down into smaller, more manageable parts.
Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Files selected for processing (1)
- sleap_io/io/nwb.py (17 hunks)
Additional context used
Ruff
sleap_io/io/nwb.py
17-20: Use `contextlib.suppress(ImportError)` instead of `try`-`except`-`pass`. Replace with `contextlib.suppress(ImportError)`. (SIM105)
157-157: Use `enumerate()` for index variable `instance_counter` in `for` loop. (SIM113)
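The SIM113 rewrite can be sketched in isolation (illustrative data below, not the actual `labels_to_pose_training` loop):

```python
items = ["nose", "thorax", "abdomen"]

# Before: a counter variable maintained by hand inside the loop.
instance_counter = 0
named_before = []
for item in items:
    named_before.append(f"instance_{instance_counter}_{item}")
    instance_counter += 1

# After: enumerate() yields the index alongside each element.
named_after = [f"instance_{i}_{item}" for i, item in enumerate(items)]

assert named_before == named_after
```

Both forms produce the same names; `enumerate()` just removes the separately tracked counter that Ruff flags.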
Additional comments not posted (7)
sleap_io/io/nwb.py (7)
186-218: LGTM! The function `pose_training_to_labels` is well-implemented and handles the conversion from `PoseTraining` to `Labels` effectively.
221-236: LGTM! The function `nwb_skeleton_to_sleap` correctly converts NWB skeleton data to SLEAP format without any apparent issues.
268-291: LGTM! The function `instance_to_skeleton_instance` is well-implemented and handles the conversion from SLEAP instance data to NWB format effectively.
294-316: LGTM! The function `videos_to_source_videos` correctly converts video data to NWB format without any apparent issues.
131-183: Confirm error handling and consider simplifying. The function `labels_to_pose_training` is complex, handling multiple layers of data aggregation. Ensure that the error handling suggested in previous reviews has been incorporated to improve robustness. Additionally, consider refactoring to simplify the function or breaking it down into smaller, more manageable parts.
Tools: Ruff — 157-157: Use `enumerate()` for index variable `instance_counter` in `for` loop. (SIM113)
239-265: Confirm error handling and consider simplifying. The function `slp_skeleton_to_nwb` converts SLEAP data to NWB format. Ensure that the error handling suggested in previous reviews has been incorporated to improve robustness. Additionally, consider refactoring to simplify the function or breaking it down into smaller, more manageable parts.
637-696: Confirm error handling and consider simplifying. The function `append_nwb_training` is complex, handling multiple layers of data aggregation. Ensure that the error handling suggested in previous reviews has been incorporated to improve robustness. Additionally, consider refactoring to simplify the function or breaking it down into smaller, more manageable parts.
Actionable comments posted: 11
🧹 Outside diff range and nitpick comments (5)
README.md (1)
44-48: Excellent addition to the documentation! The new example for saving labels to an NWB file with training data conversion is a valuable addition that aligns well with the PR objectives. It clearly demonstrates how to use the new `as_training` and `frame_inds` parameters.
A few suggestions to further improve this section:
1. Consider adding a link to more detailed documentation about these new parameters, especially explaining the significance of `as_training=True` and how users should choose `frame_inds`.
2. It might be helpful to clarify the purpose of saving individual frames. For example, you could expand on why this enhances portability and in what scenarios this would be particularly useful.
Here's a suggested expansion of the comment:
```diff
 # Save to an NWB file and convert SLEAP training data to NWB training data.
-# Note: This will save the first 3 frames of the video as individual images in a
-# subfolder next to the NWB file for portability.
+# Note: This will save the first 3 frames of the video as individual images in a
+# subfolder next to the NWB file. This enhances portability by allowing users to
+# visualize key frames without needing access to the original video file, which
+# can be particularly useful for sharing datasets or when working with large video files.
```
docs/index.md (1)
48-52: LGTM! Consider adding a brief explanation of NWB format. The new code block and comments effectively demonstrate how to save SLEAP training data to an NWB file with additional options. The example is clear and informative.
To improve clarity for users who might not be familiar with the NWB format, consider adding a brief explanation or link to more information about NWB (Neurodata Without Borders). For example:
```diff
 # Save to NWB file.
 labels.save("predictions.nwb")
+
+# NWB (Neurodata Without Borders) is a standardized neurophysiology data format.
+# For more information, visit: https://www.nwb.org/

 # Save to an NWB file and convert SLEAP training data to NWB training data.
```
sleap_io/io/main.py (2)
84-91: Clarify the usage of the `frame_inds`, `frame_path`, and `image_format` parameters. The new parameters `frame_inds`, `frame_path`, and `image_format` are only applicable when `as_training` is `True`. However, this is not immediately clear from the parameter names or the docstring. Consider renaming these parameters and updating the docstring to make their usage more explicit.
Here's a suggestion:
```python
def save_nwb(
    labels: Labels,
    filename: str,
    as_training: bool = False,
    append: bool = True,
    training_frame_indices: Optional[list[int]] = None,
    training_frame_path: Optional[str] = None,
    training_image_format: str = "png",
):
    """Save a SLEAP dataset to NWB format.

    Args:
        ...
        training_frame_indices: Optional list of labeled frame indices within the
            Labels to save when saving in training data format. No effect if
            `as_training` is `False`.
        training_frame_path: The path to a folder to save the extracted frame
            images to when saving in training data format. No effect if
            `as_training` is `False`.
        training_image_format: The image format to use when saving extracted
            frame images. No effect if `as_training` is `False`.
        ...
    """
    ...
```
229-230: Handle unknown formats consistently. The `load_file` function raises a `ValueError` with a specific error message when an unknown format is encountered. However, the error message in the `save_file` function is different. Consider using a consistent error message for both functions.
Here's a suggestion:
```python
else:
    raise ValueError(f"Unknown format '{format}' for filename: '{filename}'.")
```
339-339: Complete the TODO comment for `video.backend.fps`. The function contains a TODO comment indicating the need to change the rate to `video.backend.fps` when available.
Do you want me to assist in completing this functionality or open a GitHub issue to track this task?
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (5)
- README.md (1 hunks)
- docs/index.md (1 hunks)
- pyproject.toml (1 hunks)
- sleap_io/io/main.py (3 hunks)
- sleap_io/io/nwb.py (17 hunks)
🧰 Additional context used
Ruff
sleap_io/io/nwb.py
11-11: `os` imported but unused. Remove unused import: `os`. (F401)
17-20: Use `contextlib.suppress(ImportError)` instead of `try`-`except`-`pass`. Replace with `contextlib.suppress(ImportError)`. (SIM105)
185-185: Use `enumerate()` for index variable `instance_counter` in `for` loop. (SIM113)
🔇 Additional comments not posted (5)
sleap_io/io/nwb.py (5)
60-104: LGTM! The `save_frame_images` function is well-structured and handles the saving of frame images correctly. The error handling for missing video files is marked with a TODO comment, which is appropriate for future implementation.
107-153: LGTM! The `make_image_series` function correctly creates NWB `ImageSeries` objects from the saved frame images. The grouping of labeled frames by video and the mapping of labeled frame indices to series indices are implemented properly.
156-209: LGTM! The `labels_to_pose_training` function correctly converts SLEAP `Labels` to NWB `PoseTraining`. The creation of `TrainingFrame` and `SkeletonInstance` objects is handled properly, and the `PoseTraining` object is constructed with the necessary components.
Tools: Ruff — 185-185: Use `enumerate()` for index variable `instance_counter` in `for` loop. (SIM113)
212-244: LGTM! The `pose_training_to_labels` function correctly converts NWB `PoseTraining` back to SLEAP `Labels`. The reconstruction of `LabeledFrame` and `Instance` objects is implemented properly.
872-872: ⚠️ Potential issue — Add error handling for invalid inputs. The function lacks error handling for invalid inputs, which can improve robustness and maintainability.
```diff
 def build_track_pose_estimation_list(
     track_data_df: pd.DataFrame, timestamps: ArrayLike
 ) -> List[PoseEstimationSeries]:
+    if not isinstance(track_data_df, pd.DataFrame) or not isinstance(timestamps, (np.ndarray, list)):
+        raise ValueError("Invalid input: Expected a pandas DataFrame and an array-like object.")
```
Likely invalid or redundant comment.
```toml
# TODO: Replace with ndx-pose>=0.2.0 once released.
# https://github.com/rly/ndx-pose/issues/29
"ndx-pose @ git+https://github.com/rly/ndx-pose@a847ad4be75e60ef9e413b8cbfc99c616fc9fd05",
```
💡 Codebase verification
Dependency points to an outdated commit
The `ndx-pose` dependency is currently set to commit `a847ad4be75e60ef9e413b8cbfc99c616fc9fd05`, while the latest commit on the main branch is `8017291ae9aee84d9f0b6f1d9426b56372cd715b`. To ensure you have the most recent updates and fixes, consider updating the dependency to the latest commit until version `0.2.0` is released on PyPI.
- Update the commit hash to `8017291ae9aee84d9f0b6f1d9426b56372cd715b`.
- Monitor the PyPI repository for the release of `ndx-pose>=0.2.0` and update the dependency accordingly.
Temporary dependency solution needs follow-up
The addition of the ndx-pose
dependency from a specific GitHub commit is a temporary solution. While this allows for immediate use of the required features, it comes with some considerations:
- Reproducibility: Using a specific commit ensures consistency but may cause issues if the repository is altered or becomes unavailable.
- Maintenance: This approach requires manual updates when new versions are released.
- Security: It bypasses the usual package vetting process on PyPI.
To ensure this temporary solution doesn't persist longer than necessary, let's check the current state of the ndx-pose
package:
Consider implementing a CI/CD process to automatically check for the availability of ndx-pose>=0.2.0
on PyPI and create a pull request to update the dependency when it becomes available.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
```bash
#!/bin/bash
# Check if ndx-pose>=0.2.0 is available on PyPI
pip install ndx-pose>=0.2.0 2>&1 || echo "ndx-pose>=0.2.0 is not yet available on PyPI"

# Check the latest commit on the main branch of the ndx-pose repository
latest_commit=$(gh api repos/rly/ndx-pose/commits/main --jq '.sha')
echo "Latest commit on main branch: $latest_commit"
echo "Current commit in use: a847ad4be75e60ef9e413b8cbfc99c616fc9fd05"
```
Length of output: 518
```python
def save_nwb(
    labels: Labels,
    filename: str,
    as_training: bool = False,
    append: bool = True,
    frame_inds: Optional[list[int]] = None,
    frame_path: Optional[str] = None,
    image_format: str = "png",
):
```
🛠️ Refactor suggestion
Refactor `save_nwb` to reduce complexity and improve readability.
The function signature has been updated to include several new parameters. However, the function body can be refactored to reduce complexity and improve readability. Consider the following suggestions:
- Use a dictionary to map the `as_training` flag to the corresponding `nwb` functions. This will eliminate the need for the `if-else` block.
- Move the docstring for the `as_training` parameter to a single line to improve readability.
- Consider using more descriptive variable names for `frame_inds` and `frame_path`.
Here's a refactored version of the function:
```python
def save_nwb(
    labels: Labels,
    filename: str,
    as_training: bool = False,
    append: bool = True,
    training_frame_indices: Optional[list[int]] = None,
    training_frame_path: Optional[str] = None,
    training_image_format: str = "png",
):
    """Save a SLEAP dataset to NWB format.

    Args:
        labels: A SLEAP `Labels` object (see `load_slp`).
        filename: Path to NWB file to save to. Must end in `.nwb`.
        as_training: If `True`, save the dataset as a training dataset.
        append: If `True` (the default), append to existing NWB file. File will be
            created if it does not exist.
        training_frame_indices: Optional list of labeled frame indices within the
            Labels to save when saving in training data format.
        training_frame_path: The path to a folder to save the extracted frame
            images to when saving in training data format.
        training_image_format: The image format to use when saving extracted
            frame images.

    See also: nwb.write_nwb, nwb.append_nwb, nwb.append_nwb_training
    """
    nwb_funcs = {
        True: {
            "append": nwb.append_nwb_training,
            "write": nwb.write_nwb_training,
        },
        False: {
            "append": nwb.append_nwb,
            "write": nwb.write_nwb,
        },
    }
    action = "append" if append and Path(filename).exists() else "write"
    nwb_funcs[as_training][action](
        labels,
        filename,
        frame_inds=training_frame_indices,
        frame_path=training_frame_path,
        image_format=training_image_format,
    )
```
     See also: nwb.write_nwb, nwb.append_nwb, nwb.append_nwb_training
     """
     if append and Path(filename).exists():
-        nwb.append_nwb(labels, filename)
+        nwb.append_nwb(
+            labels,
+            filename,
+            as_training=as_training,
+            frame_inds=frame_inds,
+            frame_path=frame_path,
+            image_format=image_format,
+        )
🛠️ Refactor suggestion
Refactor the if-else block to reduce duplication.
The if-else block in the save_nwb function contains duplicated code. The only difference between the two branches is the function being called (nwb.append_nwb or nwb.write_nwb). This duplication can be eliminated by using a dictionary to map the append flag to the corresponding function.
Here's a refactored version:
nwb_funcs = {
True: {
"append": nwb.append_nwb_training,
"write": nwb.write_nwb_training,
},
False: {
"append": nwb.append_nwb,
"write": nwb.write_nwb,
},
}
action = "append" if append and Path(filename).exists() else "write"
nwb_funcs[as_training][action](
labels,
filename,
frame_inds=training_frame_indices,
frame_path=training_frame_path,
image_format=training_image_format,
)
    elif format in ("nwb", "nwb_predictions"):
        save_nwb(labels, filename, False)
    elif format == "nwb_training":
        save_nwb(labels, filename, True, frame_inds=kwargs.get("frame_inds", None))
🛠️ Refactor suggestion
Use a dictionary to map formats to save functions.
The save_file function uses a series of if-elif statements to determine which save function to call based on the format parameter. This can be simplified by using a dictionary to map formats to their corresponding save functions.
Here's a refactored version:
format_map = {
"slp": save_slp,
"nwb": lambda l, f, **kw: save_nwb(l, f, False, **kw),
"nwb_training": lambda l, f, **kw: save_nwb(l, f, True, frame_inds=kw.pop("frame_inds", None), **kw),
"labelstudio": save_labelstudio,
"jabs": lambda l, f, **kw: save_jabs(l, kw.pop("pose_version", 5), f, **kw),
}
if format in format_map:
format_map[format](labels, filename, **kwargs)
else:
    raise ValueError(f"Unknown format '{format}' for filename: '{filename}'.")
try:
    import cv2
except ImportError:
    pass
🛠️ Refactor suggestion
Use contextlib.suppress for import suppression.
Replace the try-except-pass block with contextlib.suppress to improve readability.
-try:
- import cv2
-except ImportError:
- pass
+from contextlib import suppress
+with suppress(ImportError):
+ import cv2
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
from contextlib import suppress

with suppress(ImportError):
    import cv2
🧰 Tools
Ruff
17-20: Use contextlib.suppress(ImportError) instead of try-except-pass. Replace with contextlib.suppress(ImportError) (SIM105)
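For a quick illustration of the pattern Ruff recommends, here is a self-contained sketch; the module name is a deliberately nonexistent placeholder, not the real cv2 dependency:

```python
from contextlib import suppress

# If the optional module is missing, the ImportError is swallowed and
# execution continues; any other exception would still propagate.
with suppress(ImportError):
    import definitely_not_installed_module  # hypothetical optional dependency

HAS_OPTIONAL = "definitely_not_installed_module" in globals()
print(HAS_OPTIONAL)  # → False when the module is absent
```

Downstream code can then branch on the flag instead of re-attempting the import.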
def videos_to_source_videos(videos: list[Video]) -> SourceVideos:  # type: ignore[return]
    """Converts a list of SLEAP Videos to NWB SourceVideos.

    Args:
        videos: A list of SLEAP Videos.

    Returns:
        An NWB SourceVideos object.
    """
    source_videos = []
    for i, video in enumerate(videos):
        image_series = ImageSeries(
            name=f"video_{i}",
            description="N/A",
            unit="NA",
            format="external",
            external_file=[video.filename],
            dimension=[video.backend.img_shape[0], video.backend.img_shape[1]],
            starting_frame=[0],
            rate=30.0,  # TODO - change to `video.backend.fps` when available
        )
        source_videos.append(image_series)
    return SourceVideos(image_series=source_videos)
Add error handling for input validation in videos_to_source_videos.
This function should include checks to ensure that the input videos is a list of Video objects:
def videos_to_source_videos(videos: list[Video]) -> SourceVideos: # type: ignore[return]
+ if not isinstance(videos, list) or not all(isinstance(video, Video) for video in videos):
+ raise TypeError("Expected a list of Video objects")
source_videos = []
...
📝 Committable suggestion
def videos_to_source_videos(videos: list[Video]) -> SourceVideos:  # type: ignore[return]
    """Converts a list of SLEAP Videos to NWB SourceVideos.

    Args:
        videos: A list of SLEAP Videos.

    Returns:
        An NWB SourceVideos object.
    """
    if not isinstance(videos, list) or not all(isinstance(video, Video) for video in videos):
        raise TypeError("Expected a list of Video objects")
    source_videos = []
    for i, video in enumerate(videos):
        image_series = ImageSeries(
            name=f"video_{i}",
            description="N/A",
            unit="NA",
            format="external",
            external_file=[video.filename],
            dimension=[video.backend.img_shape[0], video.backend.img_shape[1]],
            starting_frame=[0],
            rate=30.0,  # TODO - change to `video.backend.fps` when available
        )
        source_videos.append(image_series)
    return SourceVideos(image_series=source_videos)
def nwb_skeleton_to_sleap(skeleton: NWBSkeleton) -> SLEAPSkeleton:  # type: ignore[return]
    """Converts an NWB skeleton to a SLEAP skeleton.

    Args:
        skeleton: An NWB skeleton.

    Returns:
        A SLEAP skeleton.
    """
    nodes = [Node(name=node) for node in skeleton.nodes]
    edges = [Edge(source=edge[0], destination=edge[1]) for edge in skeleton.edges]
    return SLEAPSkeleton(
        nodes=nodes,
        edges=edges,
        name=skeleton.name,
    )
Add error handling and type annotations to nwb_skeleton_to_sleap.
The function lacks error handling for invalid inputs and type annotations for better readability and maintainability.
def nwb_skeleton_to_sleap(skeleton: NWBSkeleton) -> SLEAPSkeleton:
+ if skeleton is None or not hasattr(skeleton, 'nodes') or not hasattr(skeleton, 'edges'):
+ raise ValueError("Invalid skeleton: Missing required data.")
nodes = [Node(name=node) for node in skeleton.nodes]
edges = [Edge(source=edge[0], destination=edge[1]) for edge in skeleton.edges]
return SLEAPSkeleton(
nodes=nodes,
edges=edges,
name=skeleton.name,
)
Committable suggestion was skipped due to low confidence.
def slp_skeleton_to_nwb(
    skeleton: SLEAPSkeleton, subject: Optional[Subject] = None
) -> NWBSkeleton:  # type: ignore[return]
    """Converts SLEAP skeleton to NWB skeleton.

    Args:
        skeleton: A SLEAP skeleton.
        subject: An NWB subject.

    Returns:
        An NWB skeleton.
    """
    if subject is None:
        subject = Subject(species="No specified species", subject_id="No specified id")
    nwb_edges = []
    skeleton_edges = dict(enumerate(skeleton.nodes))
    for i, source in skeleton_edges.items():
        for destination in list(skeleton_edges.values())[i:]:
            if Edge(source, destination) in skeleton.edges:
                nwb_edges.append([i, list(skeleton_edges.values()).index(destination)])

    return NWBSkeleton(
        name=skeleton.name,
        nodes=skeleton.node_names,
        edges=np.array(nwb_edges, dtype=np.uint8),
        subject=subject,
    )
Add error handling for input validation in slp_skeleton_to_nwb.
This function should include checks to ensure that the input skeleton is of the correct type and has the necessary attributes:
def slp_skeleton_to_nwb(
skeleton: SLEAPSkeleton, subject: Optional[Subject] = None
) -> NWBSkeleton:
+ if not isinstance(skeleton, SLEAPSkeleton):
+ raise TypeError("Expected a SLEAPSkeleton object")
+ if not hasattr(skeleton, 'nodes') or not hasattr(skeleton, 'edges'):
+ raise ValueError("Invalid SLEAPSkeleton object: Missing 'nodes' or 'edges'")
if subject is None:
subject = Subject(species="No specified species", subject_id="No specified id")
...
📝 Committable suggestion
def slp_skeleton_to_nwb(
    skeleton: SLEAPSkeleton, subject: Optional[Subject] = None
) -> NWBSkeleton:  # type: ignore[return]
    """Converts SLEAP skeleton to NWB skeleton.

    Args:
        skeleton: A SLEAP skeleton.
        subject: An NWB subject.

    Returns:
        An NWB skeleton.
    """
    if not isinstance(skeleton, SLEAPSkeleton):
        raise TypeError("Expected a SLEAPSkeleton object")
    if not hasattr(skeleton, 'nodes') or not hasattr(skeleton, 'edges'):
        raise ValueError("Invalid SLEAPSkeleton object: Missing 'nodes' or 'edges'")
    if subject is None:
        subject = Subject(species="No specified species", subject_id="No specified id")
    nwb_edges = []
    skeleton_edges = dict(enumerate(skeleton.nodes))
    for i, source in skeleton_edges.items():
        for destination in list(skeleton_edges.values())[i:]:
            if Edge(source, destination) in skeleton.edges:
                nwb_edges.append([i, list(skeleton_edges.values()).index(destination)])

    return NWBSkeleton(
        name=skeleton.name,
        nodes=skeleton.node_names,
        edges=np.array(nwb_edges, dtype=np.uint8),
        subject=subject,
    )
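Separately from input validation, the nested loop that builds nwb_edges is quadratic in the number of nodes; the same index pairs can be derived directly from the edge list with a single node-to-index lookup. The sketch below uses a namedtuple as a stand-in for sleap_io's Edge (which holds Node objects), so it is illustrative only:

```python
from collections import namedtuple

# Hypothetical stand-in for sleap_io's Edge type.
Edge = namedtuple("Edge", ["source", "destination"])

def edges_to_index_pairs(nodes: list, edges: list) -> list:
    """Convert node-based edges to [source_index, destination_index] pairs."""
    node_index = {node: i for i, node in enumerate(nodes)}  # one O(n) pass
    return [[node_index[e.source], node_index[e.destination]] for e in edges]

nodes = ["head", "thorax", "abdomen"]
edges = [Edge("head", "thorax"), Edge("thorax", "abdomen")]
print(edges_to_index_pairs(nodes, edges))  # → [[0, 1], [1, 2]]
```

This runs in O(nodes + edges) rather than O(nodes² × edges) and avoids the repeated list(skeleton_edges.values()) materialization.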
def instance_to_skeleton_instance(
    instance: Instance,
    skeleton: NWBSkeleton,  # type: ignore[return]
    counter: int,
) -> SkeletonInstance:  # type: ignore[return]
    """Converts a SLEAP Instance to an NWB SkeletonInstance.

    Args:
        instance: A SLEAP Instance.
        skeleton: An NWB Skeleton.
        counter: An integer counter.

    Returns:
        An NWB SkeletonInstance.
    """
    points_list = list(instance.points.values())
    node_locations = np.array([[point.x, point.y] for point in points_list])
    return SkeletonInstance(
        name=f"skeleton_instance_{counter}",
        id=np.uint64(counter),
        node_locations=node_locations,
        node_visibility=[point.visible for point in instance.points.values()],
        skeleton=skeleton,
    )
Add error handling for input validation in instance_to_skeleton_instance.
This function should include checks to ensure that the inputs instance and skeleton are of the correct types and have the necessary attributes:
def instance_to_skeleton_instance(
instance: Instance, skeleton: NWBSkeleton # type: ignore[return]
) -> SkeletonInstance:
+ if not isinstance(instance, Instance):
+ raise TypeError("Expected an Instance object")
+ if not isinstance(skeleton, NWBSkeleton):
+ raise TypeError("Expected an NWBSkeleton object")
points_list = list(instance.points.values())
...
Committable suggestion was skipped due to low confidence.
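For intuition on what instance_to_skeleton_instance builds, here is a minimal sketch of turning a points mapping into the node_locations array and visibility list. The dict-of-tuples layout is a stand-in for sleap_io's Point objects (which expose .x, .y, and .visible attributes), so this is illustrative only:

```python
import numpy as np

# Stand-in for Instance.points: node name -> (x, y, visible).
points = {"head": (10.0, 20.0, True), "thorax": (15.0, 25.0, False)}

# One (n_nodes, 2) array of coordinates and a parallel visibility list.
node_locations = np.array([[x, y] for x, y, _ in points.values()])
node_visibility = [visible for _, _, visible in points.values()]

print(node_locations.shape)  # → (2, 2)
print(node_visibility)       # → [True, False]
```

The real function additionally wraps these in an ndx-pose SkeletonInstance with a name and integer id derived from the counter.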
def append_nwb_training(
    labels: Labels,
    nwbfile: NWBFile,
    pose_estimation_metadata: Optional[dict] = None,
    frame_inds: Optional[list[int]] = None,
    frame_path: Optional[str] = None,
    image_format: str = "png",
) -> NWBFile:
    """Append training data from a Labels object to an in-memory NWB file.

    Args:
        labels: A general labels object.
        nwbfile: An in-memory NWB file.
        pose_estimation_metadata: Metadata for pose estimation.
        frame_inds: The indices of the frames to write. If None, all frames are written.
        frame_path: The path to save the frames. If None, the path is the video
            filename without the extension.
        image_format: The format of the image to write. Default is "png".

    Returns:
        An in-memory NWB file with the PoseTraining data appended.
    """
    pose_estimation_metadata = pose_estimation_metadata or dict()
    provenance = labels.provenance
    default_metadata = dict(scorer=str(provenance))
    sleap_version = provenance.get("sleap_version", None)
    default_metadata["source_software_version"] = sleap_version

    subject = Subject(subject_id="No specified id", species="No specified species")
    nwbfile.subject = subject

    for i, video in enumerate(labels.videos):
        video_path = (
            Path(video.filename)
            if isinstance(video.filename, str)
            else video.filename[i]
        )
        processing_module_name = f"SLEAP_VIDEO_{i:03}_{video_path.stem}"
        nwb_processing_module = get_processing_module_for_video(
            processing_module_name, nwbfile
        )
        default_metadata["original_videos"] = [f"{video.filename}"]
        default_metadata["labeled_videos"] = [f"{video.filename}"]
        default_metadata.update(pose_estimation_metadata)

        skeletons_list = [
            slp_skeleton_to_nwb(skeleton, subject) for skeleton in labels.skeletons
        ]
        skeletons = Skeletons(skeletons=skeletons_list)
        nwb_processing_module.add(skeletons)
        lf_ind_to_img_path = save_frame_images(
            labels,
            save_path=frame_path,
            image_format=image_format,
            frame_inds=frame_inds,
        )
        image_series_list, lf_ind_to_series_ind = make_image_series(
            labels, lf_ind_to_img_path
        )
        pose_training = labels_to_pose_training(
            labels, skeletons_list, image_series_list, lf_ind_to_series_ind
        )
        nwb_processing_module.add(pose_training)

        _ = nwbfile.create_device(
            name=f"camera {i}",
            description=f"Camera used to record video {i}",
            manufacturer="No specified manufacturer",
        )
    return nwbfile
Add error handling for invalid inputs.
The function lacks error handling for invalid inputs; adding it would improve robustness and maintainability.
def append_nwb_training(
labels: Labels,
nwbfile: NWBFile,
pose_estimation_metadata: Optional[dict] = None,
frame_inds: Optional[list[int]] = None,
frame_path: Optional[str] = None,
image_format: str = "png",
) -> NWBFile:
+ if not isinstance(labels, Labels):
+ raise ValueError("Invalid input: Expected a Labels object.")
+ if not isinstance(nwbfile, NWBFile):
+ raise ValueError("Invalid input: Expected an NWBFile object.")
+ if pose_estimation_metadata is not None and not isinstance(pose_estimation_metadata, dict):
+ raise ValueError("Invalid input: Expected a dictionary for pose_estimation_metadata.")
+ if frame_inds is not None and not isinstance(frame_inds, list):
+ raise ValueError("Invalid input: Expected a list for frame_inds.")
+ if frame_path is not None and not isinstance(frame_path, str):
+ raise ValueError("Invalid input: Expected a string for frame_path.")
📝 Committable suggestion
def append_nwb_training(
    labels: Labels,
    nwbfile: NWBFile,
    pose_estimation_metadata: Optional[dict] = None,
    frame_inds: Optional[list[int]] = None,
    frame_path: Optional[str] = None,
    image_format: str = "png",
) -> NWBFile:
    """Append training data from a Labels object to an in-memory NWB file.

    Args:
        labels: A general labels object.
        nwbfile: An in-memory NWB file.
        pose_estimation_metadata: Metadata for pose estimation.
        frame_inds: The indices of the frames to write. If None, all frames are written.
        frame_path: The path to save the frames. If None, the path is the video
            filename without the extension.
        image_format: The format of the image to write. Default is "png".

    Returns:
        An in-memory NWB file with the PoseTraining data appended.
    """
    if not isinstance(labels, Labels):
        raise ValueError("Invalid input: Expected a Labels object.")
    if not isinstance(nwbfile, NWBFile):
        raise ValueError("Invalid input: Expected an NWBFile object.")
    if pose_estimation_metadata is not None and not isinstance(pose_estimation_metadata, dict):
        raise ValueError("Invalid input: Expected a dictionary for pose_estimation_metadata.")
    if frame_inds is not None and not isinstance(frame_inds, list):
        raise ValueError("Invalid input: Expected a list for frame_inds.")
    if frame_path is not None and not isinstance(frame_path, str):
        raise ValueError("Invalid input: Expected a string for frame_path.")
    pose_estimation_metadata = pose_estimation_metadata or dict()
    provenance = labels.provenance
    default_metadata = dict(scorer=str(provenance))
    sleap_version = provenance.get("sleap_version", None)
    default_metadata["source_software_version"] = sleap_version

    subject = Subject(subject_id="No specified id", species="No specified species")
    nwbfile.subject = subject

    for i, video in enumerate(labels.videos):
        video_path = (
            Path(video.filename)
            if isinstance(video.filename, str)
            else video.filename[i]
        )
        processing_module_name = f"SLEAP_VIDEO_{i:03}_{video_path.stem}"
        nwb_processing_module = get_processing_module_for_video(
            processing_module_name, nwbfile
        )
        default_metadata["original_videos"] = [f"{video.filename}"]
        default_metadata["labeled_videos"] = [f"{video.filename}"]
        default_metadata.update(pose_estimation_metadata)

        skeletons_list = [
            slp_skeleton_to_nwb(skeleton, subject) for skeleton in labels.skeletons
        ]
        skeletons = Skeletons(skeletons=skeletons_list)
        nwb_processing_module.add(skeletons)
        lf_ind_to_img_path = save_frame_images(
            labels,
            save_path=frame_path,
            image_format=image_format,
            frame_inds=frame_inds,
        )
        image_series_list, lf_ind_to_series_ind = make_image_series(
            labels, lf_ind_to_img_path
        )
        pose_training = labels_to_pose_training(
            labels, skeletons_list, image_series_list, lf_ind_to_series_ind
        )
        nwb_processing_module.add(pose_training)

        _ = nwbfile.create_device(
            name=f"camera {i}",
            description=f"Camera used to record video {i}",
            manufacturer="No specified manufacturer",
        )
    return nwbfile
Description
Right now, data from NWB (Neurodata without Borders) is partially supported in sleap-io. The PoseEstimation and PoseEstimationSeries data structures are supported, but the TrainingFrame, TrainingFrames, PoseTraining, and SourceVideos structures are not. These data structures correspond to data structures in SLEAP as shown in rly/ndx-pose#24. I have added support for these by allowing the user to export SLEAP training data as NWB training data. I have also updated the README with an example of how to use this feature.
Types of changes
Does this address any currently open issues?
#100, #86, rly/ndx-pose#29
Summary by CodeRabbit
New Features
Bug Fixes
Refactor