diff --git a/nibabies/utils/bids.py b/nibabies/utils/bids.py index 0a10d5a7..765ca47e 100644 --- a/nibabies/utils/bids.py +++ b/nibabies/utils/bids.py @@ -1,31 +1,134 @@ # emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- # vi: set ft=python sts=4 ts=4 sw=4 et: """Utilities to handle BIDS inputs.""" +from __future__ import annotations + import json import os import sys +import typing as ty import warnings -from dataclasses import dataclass, field from pathlib import Path -from typing import IO, List, Literal, Optional, Union +import nibabel as nb +import numpy as np +from bids.layout import BIDSLayout, Query + +_spec: dict = { + 't1w_mask': { + 'datatype': 'anat', + 'desc': 'brain', + 'space': 'T1w', + 'suffix': 'mask', + }, + 't1w_aseg': {'datatype': 'anat', 'desc': 'aseg', 'space': 'T1w', 'suffix': 'dseg'}, + 't2w_mask': { + 'datatype': 'anat', + 'desc': 'brain', + 'space': 'T2w', + 'suffix': 'mask', + }, + 't2w_aseg': { + 'datatype': 'anat', + 'desc': 'aseg', + 'space': 'T2w', + 'suffix': 'dseg', + }, +} + + +class Derivatives: + """ + A container class for collecting and storing derivatives. + + A specification (either dictionary or JSON file) can be used to customize derivatives and + queries. + To populate this class with derivatives, the `populate()` method must first be called. + """ -@dataclass -class BOLDGrouping: - """This class is used to facilitate the grouping of BOLD series.""" + def __getattribute__(self, attr): + """In cases where the spec may change, avoid errors.""" + try: + return object.__getattribute__(self, attr) + except AttributeError: + return None + + def __init__(self, bids_root: Path | str, spec: dict | Path | str | None = None, **args): + self.bids_root = Path(bids_root) + self.spec = _spec + if spec is not None: + if not isinstance(spec, dict): + spec: dict = json.loads(Path(spec).read_text()) + self.spec = spec + + self.names = set(self.spec.keys()) + self.references = {name: None for name in self.names} + for name in self.names: + setattr(self, name, None) + + def __repr__(self): + return '\n'.join([name for name in self.names if getattr(self, name)]) + + def __contains__(self, val: str): + return val in self.names + + def __bool__(self): + return any(getattr(self, name) for name in self.names) + + def populate( + self, deriv_path, subject_id: str, session_id: str | Query | None = Query.OPTIONAL + ) -> None: + """Query a derivatives directory and populate values and references based on the spec.""" + layout = BIDSLayout(deriv_path, validate=False) + for name, query in self.spec.items(): + items = layout.get( + subject=subject_id, + session=session_id, + extension=['.nii', '.nii.gz'], + **query, + ) + if not items or len(items) > 1: + warnings.warn(f"Could not find {name}") + continue + item = items[0] + + # Skip if derivative does not have valid metadata + metadata = item.get_metadata() + if not metadata or not (reference := metadata.get('SpatialReference')): + warnings.warn(f"No metadata found for {item}") + continue + if isinstance(reference, list): + if len(reference) > 1: + warnings.warn(f"Multiple reference found: {reference}") + continue + reference = reference[0] + + reference = self.bids_root / reference + if not self.validate(item.path, str(reference)): + warnings.warn(f"Validation failed between: {item.path} and {reference}") + continue + + setattr(self, name, Path(item.path)) + self.references[name] = reference - session: Union[str, None] - pe_dir: str - readout: float - multiecho_id: str = None - files: List[IO] = field(default_factory=list) + @property + def mask(self) -> str | None: + return self.t1w_mask or self.t2w_mask @property - def name(self) -> str: - return f"{self.session}-{self.pe_dir}-{self.readout}-{self.multiecho_id}" + def aseg(self) -> str | None: + return self.t1w_aseg or self.t2w_aseg - def add_file(self, fl) -> None: - self.files.append(fl) + @staticmethod + def validate(derivative: str, reference: str, atol: float = 1e-5) -> bool: + anat = nb.load(reference) + expected_ort = nb.aff2axcodes(anat.affine) + img = nb.load(derivative) + if nb.aff2axcodes(img.affine) != expected_ort: + return False + if img.shape != anat.shape or not np.allclose(anat.affine, img.affine, atol=atol): + return False + return True def write_bidsignore(deriv_dir): @@ -221,55 +324,11 @@ def validate_input_dir(exec_env, bids_dir, participant_label): print("bids-validator does not appear to be installed", file=sys.stderr) -def collect_precomputed_derivatives(layout, subject_id, derivatives_filters=None): - """ - Query and collect precomputed derivatives. - - This function is used to determine which workflow steps can be skipped, - based on the files found. - """ - - deriv_queries = { - 'anat_mask': { - 'datatype': 'anat', - 'desc': 'brain', - 'space': 'orig', - 'suffix': 'mask', - }, - 'anat_aseg': { - 'datatype': 'anat', - 'desc': 'aseg', - 'space': 'orig', - 'suffix': 'dseg', - }, - } - if derivatives_filters is not None: - deriv_queries.update(derivatives_filters) - - derivatives = {} - for deriv, query in deriv_queries.items(): - res = layout.get( - scope='derivatives', - subject=subject_id, - extension=['.nii', '.nii.gz'], - return_type="filename", - **query, - ) - if not res: - continue - if len(res) > 1: # Some queries may want multiple results - raise Exception( - f"When searching for <{deriv}>, found multiple results: {[f.path for f in res]}" - ) - derivatives[deriv] = res[0] - return derivatives - - def parse_bids_for_age_months( - bids_root: Union[str, Path], + bids_root: str | Path, subject_id: str, - session_id: Optional[str] = None, -) -> Optional[int]: + session_id: str | None = None, +) -> int | None: """ Given a BIDS root, query the BIDS metadata files for participant age, in months. @@ -295,8 +354,8 @@ def parse_bids_for_age_months( def _get_age_from_tsv( - bids_tsv: Path, level: Literal['session', 'participant'], key: str -) -> Optional[int]: + bids_tsv: Path, level: ty.Literal['session', 'participant'], key: str +) -> int | None: import pandas as pd df = pd.read_csv(str(bids_tsv), sep='\t') diff --git a/nibabies/utils/tests/__init__.py b/nibabies/utils/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/nibabies/utils/tests/test_bids.py b/nibabies/utils/tests/test_bids.py new file mode 100644 index 00000000..c46fe3e8 --- /dev/null +++ b/nibabies/utils/tests/test_bids.py @@ -0,0 +1,113 @@ +from __future__ import annotations + +import json +import typing as ty +from pathlib import Path + +import pytest + +from nibabies.utils import bids + + +def _create_nifti(filename: str) -> str: + import nibabel as nb + import numpy as np + + data = np.zeros((4, 4, 4), dtype='int8') + nb.Nifti1Image(data, np.eye(4)).to_filename(filename) + return filename + + +def _create_bids_dir(root_path: Path): + if not root_path.exists(): + root_path.mkdir() + anat_dir = root_path / 'sub-01' / 'anat' + anat_dir.mkdir(parents=True) + _create_nifti(str(anat_dir / 'sub-01_T1w.nii.gz')) + _create_nifti(str(anat_dir / 'sub-01_T2w.nii.gz')) + + +def _create_bids_derivs( + root_path: Path, + *, + t1w_mask: bool = False, + t1w_aseg: bool = False, + t2w_mask: bool = False, + t2w_aseg: bool = False, +): + if not root_path.exists(): + root_path.mkdir() + (root_path / 'dataset_description.json').write_text( + json.dumps( + {'Name': 'Derivatives Test', 'BIDSVersion': '1.8.0', 'DatasetType': 'derivative'} + ) + ) + anat_dir = root_path / 'sub-01' / 'anat' + anat_dir.mkdir(parents=True) + + def _create_deriv(name: str, modality: ty.Literal['t1w', 't2w']): + if modality == 't1w': + reference = 'sub-01/anat/sub-01_T1w.nii.gz' + elif modality == 't2w': + reference = 'sub-01/anat/sub-01_T2w.nii.gz' + + _create_nifti(str((anat_dir / name).with_suffix('.nii.gz'))) + (anat_dir / name).with_suffix('.json').write_text( + json.dumps({'SpatialReference': reference}) + ) + + if t1w_mask: + _create_deriv('sub-01_space-T1w_desc-brain_mask', 't1w') + if t1w_aseg: + _create_deriv('sub-01_space-T1w_desc-aseg_dseg', 't1w') + if t2w_mask: + _create_deriv('sub-01_space-T2w_desc-brain_mask', 't2w') + if t2w_aseg: + _create_deriv('sub-01_space-T2w_desc-aseg_dseg', 't2w') + + +@pytest.mark.parametrize( + 't1w_mask,t1w_aseg,t2w_mask,t2w_aseg,mask,aseg', + [ + (True, True, False, False, 't1w_mask', 't1w_aseg'), + (True, True, True, True, 't1w_mask', 't1w_aseg'), + (False, False, True, True, 't2w_mask', 't2w_aseg'), + (True, False, False, True, 't1w_mask', 't2w_aseg'), + (False, False, False, False, None, None), + ], +) +def test_derivatives( + tmp_path: Path, + t1w_mask: bool, + t1w_aseg: bool, + t2w_mask: bool, + t2w_aseg: bool, + mask: str | None, + aseg: str | None, +): + bids_dir = tmp_path / 'bids' + _create_bids_dir(bids_dir) + deriv_dir = tmp_path / 'derivatives' + _create_bids_derivs( + deriv_dir, t1w_mask=t1w_mask, t1w_aseg=t1w_aseg, t2w_mask=t2w_mask, t2w_aseg=t2w_aseg + ) + + derivatives = bids.Derivatives(bids_dir) + assert derivatives.mask is None + assert derivatives.t1w_mask is None + assert derivatives.t2w_mask is None + assert derivatives.aseg is None + assert derivatives.t1w_aseg is None + assert derivatives.t2w_aseg is None + + derivatives.populate(deriv_dir, subject_id='01') + if mask: + assert derivatives.mask == getattr(derivatives, mask) + assert derivatives.references[mask] + else: + assert derivatives.mask is None + if aseg: + assert derivatives.aseg == getattr(derivatives, aseg) + assert derivatives.references[aseg] + else: + assert derivatives.aseg == None diff --git a/nibabies/utils/validation.py b/nibabies/utils/validation.py deleted file mode 100644 index 10ee949d..00000000 --- a/nibabies/utils/validation.py +++ /dev/null @@ -1,53 +0,0 @@ -import sys - -import nibabel as nb -import numpy as np - - -def validate_t1w_derivatives(t1w_template, *, anat_mask=None, anat_aseg=None, atol=1e-5): - """ - Validate anatomical derivatives. - This function compares the input T1w's orientation and physical space to each derivative. - - Parameters - ---------- - t1w_template : str - T1w template - anat_mask : str or None - Precomputed anatomical brain mask - anat_aseg : str or None - Precomputed anatomical segmentations - atol : float - Absolute error tolerance between image origins - - Returns - ------- - validated : dict - A dictionary composed of derivative keys and validated filename values. - Derivatives that failed validation will not be included. - """ - - validated = {} - # T1w information - t1w = nb.load(t1w_template) - expected_ort = nb.aff2axcodes(t1w.affine) - - # Ensure orientation - for name, deriv_fl in zip(('anat_mask', 'anat_aseg'), (anat_mask, anat_aseg)): - if deriv_fl is None: - continue - img = nb.load(deriv_fl) - if nb.aff2axcodes(img.affine) != expected_ort: - print( - f"Orientation mismatch between {name} <{deriv_fl}> and T1w <{t1w_template}>", - file=sys.stderr, - ) - continue - if img.shape != t1w.shape or not np.allclose(t1w.affine, img.affine, atol=atol): - print( - f"Physical space mismatch between {name} <{deriv_fl}> and T1w <{t1w_template}>", - file=sys.stderr, - ) - continue - validated[name] = deriv_fl - return validated diff --git a/nibabies/workflows/anatomical/base.py b/nibabies/workflows/anatomical/base.py index 36ab2ab2..dc475b43 100644 --- a/nibabies/workflows/anatomical/base.py +++ b/nibabies/workflows/anatomical/base.py @@ -1,7 +1,8 @@ """Base anatomical preprocessing.""" -import warnings +from __future__ import annotations + +import typing as ty from pathlib import Path -from typing import Literal, Optional, Union from nipype.interfaces import utility as niu from nipype.pipeline import engine as pe @@ -10,27 +11,30 @@ from ... import config +if ty.TYPE_CHECKING: + from nibabies.utils.bids import Derivatives + def init_infant_anat_wf( *, - age_months: Optional[int], + age_months: int, ants_affine_init: bool, t1w: list, t2w: list, anat_modality: str, - bids_root: Optional[Union[str, Path]], - existing_derivatives: dict, + bids_root: str | Path, + derivatives: Derivatives, freesurfer: bool, - hires: Optional[bool], + hires: bool | None, longitudinal: bool, omp_nthreads: int, - output_dir: Union[str, Path], - segmentation_atlases: Optional[Union[str, Path]], + output_dir: str | Path, + segmentation_atlases: str | Path | None, skull_strip_mode: str, skull_strip_template: Reference, sloppy: bool, - spaces: Optional[SpatialReferences], - cifti_output: Optional[Literal['91k', '170k']], + spaces: SpatialReferences | None, + cifti_output: ty.Literal['91k', '170k'] | None, name: str = "infant_anat_wf", ) -> LiterateWorkflow: """ @@ -93,7 +97,7 @@ def init_infant_anat_wf( init_coreg_report_wf, ) from .preproc import init_anat_preproc_wf - from .registration import init_coregistration_wf + from .registration import init_coregister_derivatives_wf, init_coregistration_wf from .segmentation import init_anat_segmentations_wf from .surfaces import init_anat_ribbon_wf from .template import init_anat_template_wf @@ -102,28 +106,9 @@ def init_infant_anat_wf( num_t1w = len(t1w) if t1w else 0 num_t2w = len(t2w) if t2w else 0 - precomp_mask = existing_derivatives.get("anat_mask") - precomp_aseg = existing_derivatives.get("anat_aseg") - - # verify derivatives are relatively similar to T1w - if precomp_mask or precomp_aseg: - if num_t1w > 1: - precomp_mask = None - precomp_aseg = None - warnings.warn( - "Multiple T1w files were found; precomputed derivatives will not be used." - ) - - else: - from ...utils.validation import validate_t1w_derivatives - - validated_derivatives = ( - validate_t1w_derivatives( # compare derivatives to the first T1w - t1w[0], anat_mask=precomp_mask, anat_aseg=precomp_aseg - ) - ) - precomp_mask = validated_derivatives.get("anat_mask") - precomp_aseg = validated_derivatives.get("anat_aseg") + # Expected derivatives: Prioritize T1w space if available, otherwise fall back to T2w + deriv_mask = derivatives.mask + deriv_aseg = derivatives.aseg wf = LiterateWorkflow(name=name) desc = f"""\n @@ -186,7 +171,7 @@ def init_infant_anat_wf( desc += ( "A previously computed mask was used to skull-strip the anatomical image." - if precomp_mask + if deriv_mask else """\ The T1w-reference was then skull-stripped with a modified implementation of the `antsBrainExtraction.sh` workflow (from ANTs), using {skullstrip_tpl} @@ -214,13 +199,49 @@ def init_infant_anat_wf( cifti_output=cifti_output, ) - # Multiple anatomical files -> generate average reference + # Derivatives used based on the following truth table: + # |--------|--------|---------------------------------|------------------| + # | Has T1 | Has T2 | M-CRIB-S surface reconstruction | Derivatives Used | + # |--------|--------|---------------------------------|------------------| + # | Yes | No | No | T1 | + # | Yes | Yes | No | T1 | + # | No | Yes | No | T2 | + # | Yes | Yes | Yes | T2 | + + recon_method = config.workflow.surface_recon_method + t1w_mask = bool(derivatives.t1w_mask) + t1w_aseg = bool(derivatives.t1w_aseg) + t2w_mask = bool(derivatives.t2w_mask) + t2w_aseg = bool(derivatives.t2w_aseg) + + # The T2 derivatives are only prioritized first if MCRIBS reconstruction is to be used. + if recon_method == "mcribs": + if t2w_aseg: + t1w_aseg = False + if t2w_mask: + t1w_mask = False + # Otherwise, prioritize T1 derivatives + if t1w_mask: + t2w_mask = False + if t1w_aseg: + t2w_aseg = False + + config.loggers.workflow.info( + "Derivatives used:\n\t\n\t\n\t\n\t", + t1w_mask, + t1w_aseg, + t2w_mask, + t2w_aseg, + ) + t1w_template_wf = init_anat_template_wf( contrast="T1w", num_files=num_t1w, longitudinal=longitudinal, omp_nthreads=omp_nthreads, sloppy=sloppy, + has_mask=t1w_mask, + has_aseg=t1w_aseg, name="t1w_template_wf", ) @@ -230,16 +251,14 @@ def init_infant_anat_wf( longitudinal=longitudinal, omp_nthreads=omp_nthreads, sloppy=sloppy, + has_mask=t2w_mask, + has_aseg=t2w_aseg, name="t2w_template_wf", ) # Clean up each anatomical template # Denoise, INU, + Clipping - t1w_preproc_wf = init_anat_preproc_wf( - precomputed_mask=bool(precomp_mask), - precomputed_aseg=bool(precomp_aseg), - name="t1w_preproc_wf", - ) + t1w_preproc_wf = init_anat_preproc_wf(name="t1w_preproc_wf") t2w_preproc_wf = init_anat_preproc_wf(name="t2w_preproc_wf") if skull_strip_mode != "force": @@ -249,7 +268,8 @@ def init_infant_anat_wf( omp_nthreads=omp_nthreads, sloppy=sloppy, debug="registration" in config.execution.debug, - precomputed_mask=bool(precomp_mask), + t1w_mask=t1w_mask, + probmap=not t2w_mask, ) coreg_report_wf = init_coreg_report_wf( output_dir=output_dir, @@ -261,7 +281,7 @@ def init_infant_anat_wf( template_dir=segmentation_atlases, sloppy=sloppy, omp_nthreads=omp_nthreads, - precomp_aseg=precomp_aseg, + precomp_aseg=bool(derivatives.aseg), ) # Spatial normalization (requires segmentation) @@ -347,15 +367,44 @@ def init_infant_anat_wf( ]), ]) - if precomp_mask: - # Ensure the mask is conformed along with the T1w - t1w_preproc_wf.inputs.inputnode.in_mask = precomp_mask + # Workflow to move derivatives between T1w/T2w spaces + # May not be used, but define in case necessary. + coreg_deriv_wf = init_coregister_derivatives_wf( + t1w_mask=t1w_mask, t1w_aseg=t1w_aseg, t2w_aseg=t2w_aseg + ) + deriv_buffer = pe.Node( + niu.IdentityInterface(fields=['t2w_mask', 't1w_aseg', 't2w_aseg']), + name='deriv_buffer', + ) + if derivatives: + wf.connect([ + (coregistration_wf, coreg_deriv_wf, [('outputnode.t1w2t2w_xfm', 'inputnode.t1w2t2w_xfm')]), + (t1w_preproc_wf, coreg_deriv_wf, [('outputnode.anat_preproc', 'inputnode.t1w_ref')]), + (t2w_preproc_wf, coreg_deriv_wf, [('outputnode.anat_preproc', 'inputnode.t2w_ref')]), + ]) + + # Derivative mask is present + if t1w_mask: + t1w_template_wf.inputs.inputnode.anat_mask = derivatives.t1w_mask + t1w_template_wf.inputs.inputnode.mask_reference = derivatives.references['t1w_mask'] # fmt:off wf.connect([ - (t1w_preproc_wf, coregistration_wf, [("outputnode.anat_mask", "inputnode.in_mask")]), - (t2w_preproc_wf, coregistration_wf, [("outputnode.anat_preproc", "inputnode.in_t2w")]) + (t1w_template_wf, coregistration_wf, [('outputnode.anat_mask', 'inputnode.in_mask')]), + (t2w_preproc_wf, coregistration_wf, [('outputnode.anat_preproc', 'inputnode.in_t2w')]), + (t1w_template_wf, coreg_deriv_wf, [('outputnode.anat_mask', 'inputnode.t1w_mask')]), + (coreg_deriv_wf, deriv_buffer, [('outputnode.t2w_mask', 't2w_mask')]) ]) # fmt:on + elif t2w_mask: + t2w_template_wf.inputs.inputnode.anat_mask = derivatives.t2w_mask + t2w_template_wf.inputs.inputnode.mask_reference = derivatives.references['t2w_mask'] + # fmt:on + wf.connect([ + (t2w_template_wf, coregistration_wf, [('outputnode.anat_mask', 'inputnode.in_mask')]), + (t2w_preproc_wf, coregistration_wf, [('outputnode.anat_preproc', 'inputnode.in_t2w')]), + (t2w_template_wf, deriv_buffer, [('outputnode.anat_mask', 't2w_mask')]), + ]) + # fmt:off else: # Run brain extraction on the T2w brain_extraction_wf = init_infant_brain_extraction_wf( @@ -378,63 +427,81 @@ def init_infant_anat_wf( ]) # fmt:on - if precomp_aseg: - # Ensure the segmentation is conformed along with the T1w - t1w_preproc_wf.inputs.inputnode.in_aseg = precomp_aseg - wf.connect(t1w_preproc_wf, "outputnode.anat_aseg", anat_seg_wf, "inputnode.anat_aseg") + # Derivative segmentation is present + if derivatives.aseg: + wf.connect(deriv_buffer, 't1w_aseg', anat_seg_wf, 'inputnode.anat_aseg') + + if t1w_aseg: + t1w_template_wf.inputs.inputnode.anat_aseg = derivatives.t1w_aseg + t1w_template_wf.inputs.inputnode.aseg_reference = derivatives.references['t1w_aseg'] + # fmt:off + wf.connect([ + (t1w_template_wf, deriv_buffer, [('outputnode.anat_aseg', 't1w_aseg')]), + (t1w_template_wf, coreg_deriv_wf, [('outputnode.anat_aseg', 'inputnode.t1w_aseg')]), + (coreg_deriv_wf, deriv_buffer, [('outputnode.t2w_aseg', 't2w_aseg')]), + ]) + # fmt:on + elif t2w_aseg: + t2w_template_wf.inputs.inputnode.anat_aseg = derivatives.t2w_aseg + t2w_template_wf.inputs.inputnode.aseg_reference = derivatives.references['t2w_aseg'] + # fmt:off + wf.connect([ + (t2w_template_wf, deriv_buffer, [('outputnode.anat_aseg', 't2w_aseg')]), + (t2w_template_wf, coreg_deriv_wf, [('outputnode.anat_aseg', 'inputnode.t2w_aseg')]), + (coreg_deriv_wf, deriv_buffer, [('outputnode.t1w_aseg', 't1w_aseg')]), + ]) + # fmt:on if not freesurfer: return wf - if config.workflow.surface_recon_method == 'freesurfer': + if recon_method == 'freesurfer': from smriprep.workflows.surfaces import init_surface_recon_wf surface_recon_wf = init_surface_recon_wf(omp_nthreads=omp_nthreads, hires=hires) - elif config.workflow.surface_recon_method == 'infantfs': + elif recon_method == 'infantfs': from .surfaces import init_infantfs_surface_recon_wf # if running with precomputed aseg, or JLF, pass the aseg along to FreeSurfer - use_aseg = bool(precomp_aseg or segmentation_atlases) + use_aseg = bool(derivatives.aseg or segmentation_atlases) surface_recon_wf = init_infantfs_surface_recon_wf( age_months=age_months, use_aseg=use_aseg, ) - elif config.workflow.surface_recon_method == 'mcribs': + elif recon_method == 'mcribs': from nipype.interfaces.ants import DenoiseImage from .surfaces import init_mcribs_sphere_reg_wf, init_mcribs_surface_recon_wf - # Denoise raw T2w, since using the template / preproc resulted in intersection errors - denoise_raw_t2w = pe.Node( - DenoiseImage(dimension=3, noise_model="Rician"), name='denoise_raw_t2w' + # Denoise template T2w, since using the template / preproc resulted in intersection errors + denoise_t2w = pe.Node( + DenoiseImage(dimension=3, noise_model="Rician"), name='denoise_t2w' ) - + # t2w mask, t2w aseg surface_recon_wf = init_mcribs_surface_recon_wf( omp_nthreads=omp_nthreads, - use_aseg=bool(precomp_aseg), - use_mask=bool(precomp_mask), + use_aseg=bool(derivatives.aseg), # TODO: Incorporate mcribs segmentation + use_mask=bool(derivatives.mask), # TODO: Pass in mask regardless of derivatives mcribs_dir=str(config.execution.mcribs_dir), # Needed to preserve runs ) - # M-CRIB-S to dHCP42week (32k) sphere_reg_wf = init_mcribs_sphere_reg_wf() - # Transformed gives - if precomp_aseg: - surface_recon_wf.inputs.inputnode.ants_segs = precomp_aseg - if precomp_mask: - surface_recon_wf.inputs.inputnode.anat_mask = precomp_mask # fmt:off wf.connect([ - (inputnode, denoise_raw_t2w, [('t2w', 'input_image')]), - (denoise_raw_t2w, surface_recon_wf, [('output_image', 'inputnode.t2w')]), + (t2w_template_wf, denoise_t2w, [('outputnode.anat_ref', 'input_image')]), + (denoise_t2w, surface_recon_wf, [('output_image', 'inputnode.t2w')]), ]) # fmt:on + if derivatives.aseg: + wf.connect(deriv_buffer, 't2w_aseg', surface_recon_wf, 'inputnode.ants_segs') + if derivatives.mask: + wf.connect(deriv_buffer, 't2w_mask', surface_recon_wf, 'inputnode.anat_mask') else: raise NotImplementedError - if config.workflow.surface_recon_method in ('freesurfer', 'infantfs'): + if recon_method in ('freesurfer', 'infantfs'): from smriprep.workflows.surfaces import init_sphere_reg_wf # fsaverage to fsLR @@ -516,7 +583,7 @@ def init_infant_anat_wf( init_anat_fsLR_resampling_wf, ) - is_mcribs = config.workflow.surface_recon_method == "mcribs" + is_mcribs = recon_method == "mcribs" # handles morph_grayords_wf anat_fsLR_resampling_wf = init_anat_fsLR_resampling_wf(cifti_output, mcribs=is_mcribs) anat_derivatives_wf.get_node('inputnode').inputs.cifti_density = cifti_output diff --git a/nibabies/workflows/anatomical/preproc.py b/nibabies/workflows/anatomical/preproc.py index 5fa90849..1af9d311 100644 --- a/nibabies/workflows/anatomical/preproc.py +++ b/nibabies/workflows/anatomical/preproc.py @@ -6,8 +6,6 @@ def init_anat_preproc_wf( *, bspline_fitting_distance: int = 200, - precomputed_mask: bool = False, - precomputed_aseg: bool = False, name: str = "anat_preproc_wf", ) -> LiterateWorkflow: """Polish up raw anatomical data. @@ -32,15 +30,15 @@ def init_anat_preproc_wf( """ from nipype.interfaces.ants import DenoiseImage, N4BiasFieldCorrection from niworkflows.interfaces.header import ValidateImage - from niworkflows.interfaces.nibabel import IntensityClip, RegridToZooms + from niworkflows.interfaces.nibabel import IntensityClip wf = LiterateWorkflow(name=name) inputnode = pe.Node( - niu.IdentityInterface(fields=["in_anat", "in_mask", "in_aseg"]), + niu.IdentityInterface(fields=["in_anat"]), name="inputnode", ) outputnode = pe.Node( - niu.IdentityInterface(fields=["anat_preproc", "anat_mask", "anat_aseg"]), + niu.IdentityInterface(fields=["anat_preproc"]), name="outputnode", ) @@ -61,32 +59,8 @@ def init_anat_preproc_wf( ), name="n4_correct", ) - final_clip = pe.Node(IntensityClip(p_min=5.0, p_max=99.5), name="final_clip") - if precomputed_mask: - validate_mask = pe.Node(ValidateImage(), name="validate_mask") - regrid_mask = pe.Node(RegridToZooms(), name="regrid_mask") - # fmt:off - wf.connect([ - (inputnode, validate_mask, [("in_mask", "in_file")]), - (validate_mask, regrid_mask, [("out_file", "in_file")]), - (final_clip, regrid_mask, [(("out_file", _get_zooms), "zooms")]), - (regrid_mask, outputnode, [("out_file", "anat_mask")]), - ]) - # fmt:on - if precomputed_aseg: - validate_aseg = pe.Node(ValidateImage(), name="validate_aseg") - regrid_aseg = pe.Node(RegridToZooms(), name="regrid_aseg") - # fmt:off - wf.connect([ - (inputnode, validate_aseg, [("in_aseg", "in_file")]), - (validate_aseg, regrid_aseg, [("out_file", "in_file")]), - (final_clip, regrid_aseg, [(("out_file", _get_zooms), "zooms")]), - (regrid_aseg, outputnode, [("out_file", "anat_aseg")]), - ]) - # fmt:on - # fmt:off wf.connect([ (inputnode, validate, [("in_anat", "in_file")]), @@ -98,9 +72,3 @@ def init_anat_preproc_wf( ]) # fmt:on return wf - - -def _get_zooms(in_file): - import nibabel as nb - - return tuple(nb.load(in_file).header.get_zooms()[:3]) diff --git a/nibabies/workflows/anatomical/registration.py b/nibabies/workflows/anatomical/registration.py index 6491eafa..f2f25cbe 100644 --- a/nibabies/workflows/anatomical/registration.py +++ b/nibabies/workflows/anatomical/registration.py @@ -1,21 +1,22 @@ # emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- # vi: set ft=python sts=4 ts=4 sw=4 et: """Within-baby registration of a T1w into a T2w image.""" -from typing import Optional +from __future__ import annotations from nipype.interfaces import utility as niu from nipype.pipeline import engine as pe -from pkg_resources import resource_filename as pkgr_fn +from niworkflows.interfaces.fixes import FixHeaderApplyTransforms as ApplyTransforms def init_coregistration_wf( *, bspline_fitting_distance: int = 200, mem_gb: float = 3.0, - omp_nthreads: Optional[int] = None, + omp_nthreads: int | None = None, sloppy: bool = False, debug: bool = False, - precomputed_mask: bool = False, + t1w_mask: bool = False, + probmap: bool = True, name: str = "coregistration_wf", ): """ @@ -52,11 +53,12 @@ def init_coregistration_wf( Run in *sloppy* mode. debug : :obj:`bool` Produce intermediate registration files - precomputed_mask : :obj:`bool` + t1w_mask : :obj:`bool` A precomputed mask for the T1w is available. In this case, generate a quick mask to assist in coregistration, but use the precomputed mask as the final output. - + probmap: :obj:`bool` + A probabilistic brainmask is present in T2w space. Inputs ------ @@ -66,8 +68,8 @@ def init_coregistration_wf( The preprocessed input T2w image (Denoising/INU/Clipping) in_mask : :obj:`str` The brainmask. - If `precomputed_mask` is False, will be in T2w space. - If `precomputed_mask` is True, will be in T1w space. + If `t1w_mask` is False, will be in T2w space. + If `t1w_mask` is True, will be in T1w space. in_probmap : :obj:`str` The probabilistic brainmask, as obtained in T2w space. @@ -172,7 +174,7 @@ def init_coregistration_wf( ]) # fmt: on - if precomputed_mask: + if t1w_mask: # The input mask is already in T1w space. # Generate a quick, rough mask of the T2w to be used to facilitate co-registration. from sdcflows.interfaces.brainmask import BrainExtraction @@ -187,17 +189,19 @@ def init_coregistration_wf( (inputnode, outputnode, [("in_mask", "t1w_mask")]), ]) # fmt:on - else: + return workflow + + if probmap: # The T2w mask from the brain extraction workflow will be mapped to T1w space map_mask = pe.Node(ApplyTransforms(interpolation="Gaussian"), name="map_mask", mem_gb=1) thr_mask = pe.Node(Binarize(thresh_low=0.80), name="thr_mask") # fmt:off workflow.connect([ + (inputnode, reg_mask, [("in_mask", "in_file")]), + (inputnode, refine_mask, [("in_mask", "in_file")]), (inputnode, map_mask, [ ("in_t1w", "reference_image"), ("in_probmap", "input_image")]), - (inputnode, reg_mask, [("in_mask", "in_file")]), - (inputnode, refine_mask, [("in_mask", "in_file")]), (coreg, map_mask, [ ("reverse_transforms", "transforms"), ("reverse_invert_flags", "invert_transform_flags")]), @@ -207,4 +211,77 @@ def init_coregistration_wf( (thr_mask, apply_mask, [("out_mask", "in_mask")]), ]) # fmt:on + return workflow + + # A precomputed T2w mask was provided + map_precomp_mask = pe.Node( + ApplyTransforms(interpolation="MultiLabel"), name='map_precomp_mask' + ) + # fmt:off + workflow.connect([ + (inputnode, reg_mask, [("in_mask", "in_file")]), + (inputnode, refine_mask, [("in_mask", "in_file")]), + (inputnode, map_precomp_mask, [ + ('in_t1w', 'reference_image'), + ('in_mask', 'input_image')]), + (coreg, map_precomp_mask, [ + ("reverse_transforms", "transforms"), + ("reverse_invert_flags", "invert_transform_flags")]), + (map_precomp_mask, final_n4, [('output_image', 'weight_image')]), + (map_precomp_mask, outputnode, [('output_image', 't1w_mask')]), + (map_precomp_mask, apply_mask, [('output_image', 'in_mask')]), + ]) + # fmt:on + return workflow + + +def init_coregister_derivatives_wf( + *, t1w_mask: bool, t1w_aseg: bool, t2w_aseg: bool, name: str = 'coregister_derivatives_wf' +): + """Move derivatives from T1w / T2w space.""" + workflow = pe.Workflow(name=name) + inputnode = pe.Node( + niu.IdentityInterface( + fields=['t1w_ref', 't2w_ref', 't1w2t2w_xfm', 't1w_mask', 't1w_aseg', 't2w_aseg'] + ), + name='inputnode', + ) + outputnode = pe.Node( + niu.IdentityInterface(fields=['t2w_mask', 't1w_aseg', 't2w_aseg']), name='outputnode' + ) + + if t1w_mask: + t1wmask2t2w = pe.Node(ApplyTransforms(interpolation="MultiLabel"), name='t1wmask2t2w') + # fmt:off + workflow.connect([ + (inputnode, t1wmask2t2w, [ + ('t1w_mask', 'input_image'), + ('t1w2t2w_xfm', 'transforms'), + ('t2w_ref', 'reference_image')]), + (t1wmask2t2w, outputnode, [('output_image', 't2w_mask')]) + ]) + # fmt:on + if t1w_aseg: + # fmt:off + t1waseg2t2w = pe.Node(ApplyTransforms(interpolation="MultiLabel"), name='t1waseg2t2w') + workflow.connect([ + (inputnode, t1waseg2t2w, [ + ('t1w_aseg', 'input_image'), + ('t1w2t2w_xfm', 'transforms'), + ('t2w_ref', 'reference_image')]), + (t1waseg2t2w, outputnode, [('output_image', 't2w_aseg')]) + ]) + # fmt:on + if t2w_aseg: + # fmt:off + t2waseg2t1w = pe.Node(ApplyTransforms(interpolation="MultiLabel"), name='t2waseg2t1w') + t2waseg2t1w.inputs.invert_transform_flags = [True, False] + workflow.connect([ + (inputnode, t2waseg2t1w, [ + ('t2w_aseg', 'input_image'), + ('t1w2t2w_xfm', 'transforms'), + ('t1w_ref', 'reference_image')]), + (t2waseg2t1w, outputnode, [('output_image', 't1w_aseg')]) + ]) + # fmt:on return workflow diff --git a/nibabies/workflows/anatomical/template.py b/nibabies/workflows/anatomical/template.py index c7af3e4e..333220e1 100644 --- a/nibabies/workflows/anatomical/template.py +++ b/nibabies/workflows/anatomical/template.py @@ -1,7 +1,10 @@ """Prepare anatomical images for processing.""" +from __future__ import annotations + from nipype.interfaces import utility as niu from nipype.pipeline import engine as pe from niworkflows.engine.workflows import LiterateWorkflow +from niworkflows.interfaces.fixes import FixHeaderApplyTransforms as ApplyTransforms def init_anat_template_wf( @@ -12,6 +15,8 @@ def init_anat_template_wf( longitudinal: bool = False, bspline_fitting_distance: int = 200, sloppy: bool = False, + has_mask: bool = False, + has_aseg: bool = False, name: str = "anat_template_wf", ) -> LiterateWorkflow: """ @@ -45,6 +50,11 @@ def init_anat_template_wf( ------ anat_files List of structural images + anat_mask + mask_reference + anat_aseg + aseg_reference + Outputs ------- anat_ref @@ -55,12 +65,15 @@ def init_anat_template_wf( List of affine transforms to realign input images to final reference out_report Conformation report + anat_mask + Mask (if provided), resampled to the anatomical reference + anat_aseg + Aseg (if provided), resampled to the anatomical reference """ from nipype.interfaces.ants import N4BiasFieldCorrection from nipype.interfaces.image import Reorient from niworkflows.interfaces.freesurfer import PatchedLTAConvert as LTAConvert from niworkflows.interfaces.freesurfer import StructuralReference - from niworkflows.interfaces.header import ValidateImage from niworkflows.interfaces.images import Conform, TemplateDimensions from niworkflows.interfaces.nibabel import IntensityClip from niworkflows.interfaces.nitransforms import ConcatenateXFMs @@ -80,7 +93,18 @@ def init_anat_template_wf( """ inputnode = pe.Node( - niu.IdentityInterface(fields=["anat_files", "anat_mask", "anat_aseg"]), name="inputnode" + niu.IdentityInterface( + fields=[ + "anat_files", + # Each derivative requires a reference file, which will be used to find which + # transform to apply in the case when multiple runs are present + "anat_mask", + "mask_reference", + "anat_aseg", + "aseg_reference", + ] + ), + name="inputnode", ) outputnode = pe.Node( niu.IdentityInterface( @@ -89,6 +113,8 @@ def init_anat_template_wf( "anat_valid_list", "anat_realign_xfm", "out_report", + "anat_mask", + "anat_aseg", ], ), name="outputnode", @@ -110,6 +136,28 @@ def init_anat_template_wf( ]) # fmt:on + if has_mask: + mask_conform = pe.Node(Conform(), name='mask_conform') + # fmt:off + wf.connect([ + (inputnode, mask_conform, [('anat_mask', 'in_file')]), + (anat_ref_dimensions, mask_conform, [ + ('target_zooms', 'target_zooms'), + ('target_shape', 'target_shape')]), + ]) + # fmt:on + + if has_aseg: + aseg_conform = pe.Node(Conform(), name='aseg_conform') + # fmt:off + wf.connect([ + (inputnode, aseg_conform, [('anat_aseg', 'in_file')]), + (anat_ref_dimensions, aseg_conform, [ + ('target_zooms', 'target_zooms'), + ('target_shape', 'target_shape')]), + ]) + # fmt:on + if num_files == 1: get1st = pe.Node(niu.Select(index=[0]), name="get1st") outputnode.inputs.anat_realign_xfm = [ @@ -122,6 +170,10 @@ def init_anat_template_wf( (get1st, outputnode, [('out', 'anat_ref')]), ]) # fmt:on + if has_mask: + wf.connect(mask_conform, 'out_file', outputnode, 'anat_mask') + if has_aseg: + wf.connect(aseg_conform, 'out_file', outputnode, 'anat_aseg') return wf anat_conform_xfm = pe.MapNode( @@ -180,6 +232,52 @@ def init_anat_template_wf( run_without_submitting=True, ) + if has_mask: + mask_ref_idx = pe.Node( + niu.Function(function=get_reference), name='mask_ref_idx', run_without_submitting=True + ) + mask_xfm = pe.Node(niu.Select(), name='mask_xfm', run_without_submitting=True) + applyxfm_mask = pe.Node( + ApplyTransforms(interpolation='MultiLabel'), name='applyxfm_mask', mem_gb=1 + ) + mask_reorient = pe.Node(Reorient(), name="mask_reorient") + # fmt:off + wf.connect([ + (inputnode, mask_ref_idx, [('mask_reference', 'anat_reference')]), + (anat_ref_dimensions, mask_ref_idx, [('t1w_valid_list', 'anatomicals')]), + (concat_xfms, mask_xfm, [('out_xfm', 'inlist')]), + (mask_ref_idx, mask_xfm, [('out', 'index')]), + (mask_conform, applyxfm_mask, [('out_file', 'input_image')]), + (anat_reorient, applyxfm_mask, [('out_file', 'reference_image')]), + (mask_xfm, applyxfm_mask, [('out', 'transforms')]), + (applyxfm_mask, mask_reorient, [('output_image', 'in_file')]), + (mask_reorient, outputnode, [('out_file', 'anat_mask')]), + ]) + # fmt:on + + if has_aseg: + aseg_ref_idx = pe.Node( + niu.Function(function=get_reference), name='aseg_ref_idx', run_without_submitting=True + ) + aseg_xfm = pe.Node(niu.Select(), name='aseg_xfm', run_without_submitting=True) + applyxfm_aseg = pe.Node( + ApplyTransforms(interpolation='MultiLabel'), name='applyxfm_aseg', mem_gb=1 + ) + aseg_reorient = pe.Node(Reorient(), name="aseg_reorient") + # fmt:off + wf.connect([ + (inputnode, aseg_ref_idx, [('aseg_reference', 'anat_reference')]), + (anat_ref_dimensions, aseg_ref_idx, [('t1w_valid_list', 'anatomicals')]), + (concat_xfms, aseg_xfm, [('out_xfm', 'inlist')]), + (aseg_ref_idx, aseg_xfm, [('out', 'index')]), + (aseg_conform, applyxfm_aseg, [('out_file', 'input_image')]), + (anat_reorient, applyxfm_aseg, [('out_file', 'reference_image')]), + (aseg_xfm, applyxfm_aseg, [('out', 'transforms')]), + (applyxfm_aseg, aseg_reorient, [('output_image', 'in_file')]), + (applyxfm_aseg, outputnode, [('out_file', 'anat_aseg')]), + ]) + # fmt:on + def _set_threads(in_list, maximum): return min(len(in_list), maximum) @@ -204,3 +302,7 @@ def _set_threads(in_list, maximum): ]) # fmt:on return wf + + +def get_reference(anatomicals: list, anat_reference: str) -> int: + return anatomicals.index(anat_reference) diff --git a/nibabies/workflows/base.py b/nibabies/workflows/base.py index 5ce36bf0..eecfc5af 100644 --- a/nibabies/workflows/base.py +++ b/nibabies/workflows/base.py @@ -39,11 +39,12 @@ .. autofunction:: init_single_subject_wf """ +from __future__ import annotations import os import sys +import typing as ty from copy import deepcopy -from typing import Optional from nipype.interfaces import utility as niu from nipype.pipeline import engine as pe @@ -55,6 +56,9 @@ from nibabies.utils.bids import parse_bids_for_age_months from nibabies.workflows.bold import init_func_preproc_wf +if ty.TYPE_CHECKING: + from niworkflows.utils.spaces import SpatialReferences + def init_nibabies_wf(subworkflows_list): """ @@ -155,9 +159,9 @@ def init_nibabies_wf(subworkflows_list): def init_single_subject_wf( subject_id: str, - session_id: Optional[str] = None, - age: Optional[int] = None, - spaces=None, + session_id: str | None = None, + age: int | None = None, + spaces: SpatialReferences | None = None, ): """ Organize the preprocessing pipeline for a single subject, at a single session. @@ -200,6 +204,7 @@ def init_single_subject_wf( from niworkflows.utils.bids import collect_data from niworkflows.utils.spaces import Reference + from ..utils.bids import Derivatives from ..utils.misc import fix_multi_source_name from .anatomical import init_infant_anat_wf @@ -223,7 +228,7 @@ def init_single_subject_wf( subject_data["t2w"] = [] anat_only = config.workflow.anat_only - derivatives = config.execution.derivatives or {} + derivatives = Derivatives(bids_root=config.execution.layout.root) anat_modality = "t1w" if subject_data["t1w"] else "t2w" # Make sure we always go through these two checks if not anat_only and not subject_data["bold"]: @@ -235,16 +240,15 @@ def init_single_subject_wf( ) ) - if derivatives: - from ..utils.bids import collect_precomputed_derivatives - - derivatives = collect_precomputed_derivatives( - config.execution.layout, - subject_id, - derivatives_filters=config.execution.derivatives_filters, - # session_id=None, # TODO: Ensure session is visible at workflow level - ) - config.loggers.workflow.info(f"Found precomputed derivatives: {derivatives}") + if config.execution.derivatives: + for deriv_path in config.execution.derivatives: + config.loggers.workflow.info("Searching for derivatives in %s", deriv_path) + derivatives.populate( + deriv_path, + subject_id, + session_id=session_id, + ) + config.loggers.workflow.info("Found precomputed derivatives %s", derivatives) workflow = Workflow(name=name) workflow.__desc__ = """ @@ -348,7 +352,7 @@ def init_single_subject_wf( t1w=subject_data["t1w"], t2w=subject_data["t2w"], bids_root=config.execution.bids_dir, - existing_derivatives=derivatives, + derivatives=derivatives, freesurfer=config.workflow.run_reconall, hires=config.workflow.hires, longitudinal=config.workflow.longitudinal, @@ -558,7 +562,7 @@ def _prefix(subid): return subid if subid.startswith("sub-") else f"sub-{subid}" -def init_workflow_spaces(execution_spaces, age_months): +def init_workflow_spaces(execution_spaces: SpatialReferences, age_months: int): """ Create output spaces at a per-subworkflow level. diff --git a/wrapper/src/nibabies_wrapper/__main__.py b/wrapper/src/nibabies_wrapper/__main__.py index 8873088b..147a056f 100755 --- a/wrapper/src/nibabies_wrapper/__main__.py +++ b/wrapper/src/nibabies_wrapper/__main__.py @@ -29,7 +29,7 @@ MISSING = """ Image '{}' is missing Would you like to download? [Y/n] """ -PKG_PATH = "/opt/conda/lib/python3.9/site-packages" +PKG_PATH = "/opt/conda/envs/nibabies/lib/python3.10/site-packages" TF_TEMPLATES = ( "MNI152Lin", "MNI152NLin2009cAsym",