Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH+RF: Allow precomputed derivatives in T1w or T2w space #305

Merged
merged 19 commits into from
Aug 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 122 additions & 63 deletions nibabies/utils/bids.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,134 @@
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
"""Utilities to handle BIDS inputs."""
from __future__ import annotations

import json
import os
import sys
import typing as ty
import warnings
from dataclasses import dataclass, field
from pathlib import Path
from typing import IO, List, Literal, Optional, Union

import nibabel as nb
import numpy as np
from bids.layout import BIDSLayout, Query

_spec: dict = {
't1w_mask': {
'datatype': 'anat',
'desc': 'brain',
'space': 'T1w',
'suffix': 'mask',
},
't1w_aseg': {'datatype': 'anat', 'desc': 'aseg', 'space': 'T1w', 'suffix': 'dseg'},
't2w_mask': {
'datatype': 'anat',
'desc': 'brain',
'space': 'T2w',
'suffix': 'mask',
},
't2w_aseg': {
'datatype': 'anat',
'desc': 'aseg',
'space': 'T2w',
'suffix': 'dseg',
},
}


class Derivatives:
"""
A container class for collecting and storing derivatives.

A specification (either dictionary or JSON file) can be used to customize derivatives and
queries.
To populate this class with derivatives, the `populate()` method must first be called.
"""

@dataclass
class BOLDGrouping:
"""This class is used to facilitate the grouping of BOLD series."""
def __getattribute__(self, attr):
"""In cases where the spec may change, avoid errors."""
try:
return object.__getattribute__(self, attr)
except AttributeError:
return None

Check warning on line 54 in nibabies/utils/bids.py

View check run for this annotation

Codecov / codecov/patch

nibabies/utils/bids.py#L53-L54

Added lines #L53 - L54 were not covered by tests

def __init__(self, bids_root: Path | str, spec: dict | Path | str | None = None, **args):
self.bids_root = Path(bids_root)
self.spec = _spec
if spec is not None:
if not isinstance(spec, dict):
spec: dict = json.loads(Path(spec).read_text())
self.spec = spec

Check warning on line 62 in nibabies/utils/bids.py

View check run for this annotation

Codecov / codecov/patch

nibabies/utils/bids.py#L60-L62

Added lines #L60 - L62 were not covered by tests

self.names = set(self.spec.keys())
self.references = {name: None for name in self.names}
for name in self.names:
setattr(self, name, None)

def __repr__(self):
return '\n'.join([name for name in self.names if getattr(self, name)])

Check warning on line 70 in nibabies/utils/bids.py

View check run for this annotation

Codecov / codecov/patch

nibabies/utils/bids.py#L70

Added line #L70 was not covered by tests

def __contains__(self, val: str):
return val in self.names

Check warning on line 73 in nibabies/utils/bids.py

View check run for this annotation

Codecov / codecov/patch

nibabies/utils/bids.py#L73

Added line #L73 was not covered by tests

def __bool__(self):
return any(getattr(self, name) for name in self.names)

Check warning on line 76 in nibabies/utils/bids.py

View check run for this annotation

Codecov / codecov/patch

nibabies/utils/bids.py#L76

Added line #L76 was not covered by tests

def populate(
self, deriv_path, subject_id: str, session_id: str | Query | None = Query.OPTIONAL
) -> None:
"""Query a derivatives directory and populate values and references based on the spec."""
layout = BIDSLayout(deriv_path, validate=False)
for name, query in self.spec.items():
items = layout.get(
subject=subject_id,
session=session_id,
extension=['.nii', '.nii.gz'],
**query,
)
if not items or len(items) > 1:
warnings.warn(f"Could not find {name}")
continue
item = items[0]

# Skip if derivative does not have valid metadata
metadata = item.get_metadata()
if not metadata or not (reference := metadata.get('SpatialReference')):
warnings.warn(f"No metadata found for {item}")
continue

Check warning on line 99 in nibabies/utils/bids.py

View check run for this annotation

Codecov / codecov/patch

nibabies/utils/bids.py#L98-L99

Added lines #L98 - L99 were not covered by tests
if isinstance(reference, list):
if len(reference) > 1:
warnings.warn(f"Multiple reference found: {reference}")
continue
reference = reference[0]

Check warning on line 104 in nibabies/utils/bids.py

View check run for this annotation

Codecov / codecov/patch

nibabies/utils/bids.py#L101-L104

Added lines #L101 - L104 were not covered by tests

reference = self.bids_root / reference
if not self.validate(item.path, str(reference)):
warnings.warn(f"Validation failed between: {item.path} and {reference}")
continue

Check warning on line 109 in nibabies/utils/bids.py

View check run for this annotation

Codecov / codecov/patch

nibabies/utils/bids.py#L108-L109

Added lines #L108 - L109 were not covered by tests

setattr(self, name, Path(item.path))
self.references[name] = reference

session: Union[str, None]
pe_dir: str
readout: float
multiecho_id: str = None
files: List[IO] = field(default_factory=list)
@property
def mask(self) -> str | None:
return self.t1w_mask or self.t2w_mask

@property
def name(self) -> str:
return f"{self.session}-{self.pe_dir}-{self.readout}-{self.multiecho_id}"
def aseg(self) -> str | None:
return self.t1w_aseg or self.t2w_aseg

def add_file(self, fl) -> None:
self.files.append(fl)
@staticmethod
def validate(derivative: str, reference: str, atol: float = 1e-5) -> bool:
anat = nb.load(reference)
expected_ort = nb.aff2axcodes(anat.affine)
img = nb.load(derivative)
if nb.aff2axcodes(img.affine) != expected_ort:
return False

Check warning on line 128 in nibabies/utils/bids.py

View check run for this annotation

Codecov / codecov/patch

nibabies/utils/bids.py#L128

Added line #L128 was not covered by tests
if img.shape != anat.shape or not np.allclose(anat.affine, img.affine, atol=atol):
return False

Check warning on line 130 in nibabies/utils/bids.py

View check run for this annotation

Codecov / codecov/patch

nibabies/utils/bids.py#L130

Added line #L130 was not covered by tests
return True


def write_bidsignore(deriv_dir):
Expand Down Expand Up @@ -221,55 +324,11 @@
print("bids-validator does not appear to be installed", file=sys.stderr)


def collect_precomputed_derivatives(layout, subject_id, derivatives_filters=None):
"""
Query and collect precomputed derivatives.

This function is used to determine which workflow steps can be skipped,
based on the files found.
"""

deriv_queries = {
'anat_mask': {
'datatype': 'anat',
'desc': 'brain',
'space': 'orig',
'suffix': 'mask',
},
'anat_aseg': {
'datatype': 'anat',
'desc': 'aseg',
'space': 'orig',
'suffix': 'dseg',
},
}
if derivatives_filters is not None:
deriv_queries.update(derivatives_filters)

derivatives = {}
for deriv, query in deriv_queries.items():
res = layout.get(
scope='derivatives',
subject=subject_id,
extension=['.nii', '.nii.gz'],
return_type="filename",
**query,
)
if not res:
continue
if len(res) > 1: # Some queries may want multiple results
raise Exception(
f"When searching for <{deriv}>, found multiple results: {[f.path for f in res]}"
)
derivatives[deriv] = res[0]
return derivatives


def parse_bids_for_age_months(
bids_root: Union[str, Path],
bids_root: str | Path,
subject_id: str,
session_id: Optional[str] = None,
) -> Optional[int]:
session_id: str | None = None,
) -> int | None:
"""
Given a BIDS root, query the BIDS metadata files for participant age, in months.

Expand All @@ -295,8 +354,8 @@


def _get_age_from_tsv(
bids_tsv: Path, level: Literal['session', 'participant'], key: str
) -> Optional[int]:
bids_tsv: Path, level: ty.Literal['session', 'participant'], key: str
) -> int | None:
import pandas as pd

df = pd.read_csv(str(bids_tsv), sep='\t')
Expand Down
Empty file.
113 changes: 113 additions & 0 deletions nibabies/utils/tests/test_bids.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
from __future__ import annotations

import json
import typing as ty
from pathlib import Path

import pytest

from nibabies.utils import bids


def _create_nifti(filename: str) -> str:
import nibabel as nb
import numpy as np

data = np.zeros((4, 4, 4), dtype='int8')
nb.Nifti1Image(data, np.eye(4)).to_filename(filename)
return filename


def _create_bids_dir(root_path: Path):
if not root_path.exists():
root_path.mkdir()
anat_dir = root_path / 'sub-01' / 'anat'
anat_dir.mkdir(parents=True)
_create_nifti(str(anat_dir / 'sub-01_T1w.nii.gz'))
_create_nifti(str(anat_dir / 'sub-01_T2w.nii.gz'))


def _create_bids_derivs(
root_path: Path,
*,
t1w_mask: bool = False,
t1w_aseg: bool = False,
t2w_mask: bool = False,
t2w_aseg: bool = False,
):
if not root_path.exists():
root_path.mkdir()
(root_path / 'dataset_description.json').write_text(
json.dumps(
{'Name': 'Derivatives Test', 'BIDSVersion': '1.8.0', 'DatasetType': 'derivative'}
)
)
anat_dir = root_path / 'sub-01' / 'anat'
anat_dir.mkdir(parents=True)

def _create_deriv(name: str, modality: ty.Literal['t1w', 't2w']):
if modality == 't1w':
reference = 'sub-01/anat/sub-01_T1w.nii.gz'
elif modality == 't2w':
reference = 'sub-01/anat/sub-01_T2w.nii.gz'

_create_nifti(str((anat_dir / name).with_suffix('.nii.gz')))
(anat_dir / name).with_suffix('.json').write_text(
json.dumps({'SpatialReference': reference})
)

if t1w_mask:
_create_deriv('sub-01_space-T1w_desc-brain_mask', 't1w')
if t1w_aseg:
_create_deriv('sub-01_space-T1w_desc-aseg_dseg', 't1w')
if t2w_mask:
_create_deriv('sub-01_space-T2w_desc-brain_mask', 't2w')
if t2w_aseg:
_create_deriv('sub-01_space-T2w_desc-aseg_dseg', 't2w')


@pytest.mark.parametrize(
't1w_mask,t1w_aseg,t2w_mask,t2w_aseg,mask,aseg',
[
(True, True, False, False, 't1w_mask', 't1w_aseg'),
(True, True, True, True, 't1w_mask', 't1w_aseg'),
(False, False, True, True, 't2w_mask', 't2w_aseg'),
(True, False, False, True, 't1w_mask', 't2w_aseg'),
(False, False, False, False, None, None),
],
)
def test_derivatives(
tmp_path: Path,
t1w_mask: bool,
t1w_aseg: bool,
t2w_mask: bool,
t2w_aseg: bool,
mask: str | None,
aseg: str | None,
):
bids_dir = tmp_path / 'bids'
_create_bids_dir(bids_dir)
deriv_dir = tmp_path / 'derivatives'
_create_bids_derivs(
deriv_dir, t1w_mask=t1w_mask, t1w_aseg=t1w_aseg, t2w_mask=t2w_mask, t2w_aseg=t2w_aseg
)

derivatives = bids.Derivatives(bids_dir)
assert derivatives.mask is None
assert derivatives.t1w_mask is None
assert derivatives.t2w_mask is None
assert derivatives.aseg is None
assert derivatives.t1w_aseg is None
assert derivatives.t2w_aseg is None

derivatives.populate(deriv_dir, subject_id='01')
if mask:
assert derivatives.mask == getattr(derivatives, mask)
assert derivatives.references[mask]
else:
assert derivatives.mask is None
if aseg:
assert derivatives.aseg == getattr(derivatives, aseg)
assert derivatives.references[aseg]
else:
assert derivatives.aseg == None
53 changes: 0 additions & 53 deletions nibabies/utils/validation.py

This file was deleted.

Loading