Skip to content

Commit

Permalink
fix: refactor the brain extraction workflow
Browse files Browse the repository at this point in the history
This commit:

  - [x] Updates the nodes with pure python interfaces based on nibabel,
    minimizing the need for the new ``copy_header`` of ANTs' nipype
    interfaces.
  - [x] Reorganizes the workflow so that the Atropos refinement is
    completely self contained.

These are the first two steps to address nipreps/smriprep#125.
  • Loading branch information
oesteban committed Sep 16, 2020
1 parent 9c3eb81 commit 0eb96fb
Showing 1 changed file with 107 additions and 81 deletions.
188 changes: 107 additions & 81 deletions niworkflows/anat/ants.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,12 @@
# nipype
from nipype.pipeline import engine as pe
from nipype.interfaces import utility as niu
from nipype.interfaces.fsl.maths import ApplyMask
from nipype.interfaces.ants import (
AI,
Atropos,
ImageMath,
MultiplyImages,
N4BiasFieldCorrection,
ResampleImageBySpacing,
ThresholdImage,
)

Expand All @@ -31,8 +29,9 @@
FixHeaderRegistration as Registration,
FixHeaderApplyTransforms as ApplyTransforms,
)
from ..interfaces.images import RegridToZooms
from ..interfaces.nibabel import ApplyMask, Binarize
from ..interfaces.utils import CopyXForm
from ..interfaces.nibabel import Binarize


ATROPOS_MODELS = {
Expand All @@ -41,6 +40,8 @@
"FLAIR": OrderedDict([("nclasses", 3), ("csf", 1), ("gm", 3), ("wm", 2)]),
}

_ants_version = Registration().version


def init_brain_extraction_wf(
name="brain_extraction_wf",
Expand Down Expand Up @@ -203,14 +204,8 @@ def init_brain_extraction_wf(
name="outputnode",
)

copy_xform = pe.Node(
CopyXForm(fields=["out_file", "out_mask", "bias_corrected", "bias_image"]),
name="copy_xform",
run_without_submitting=True,
)

trunc = pe.MapNode(
ImageMath(operation="TruncateImageIntensity", op2="0.01 0.999 256"),
ImageMath(operation="TruncateImageIntensity", op2="0.01 0.999 256", copy_header=True),
name="truncate_images",
iterfield=["op1"],
)
Expand All @@ -229,20 +224,15 @@ def init_brain_extraction_wf(
iterfield=["input_image"],
)

res_tmpl = pe.Node(
ResampleImageBySpacing(out_spacing=(4, 4, 4), apply_smoothing=True),
name="res_tmpl",
)
res_tmpl.inputs.input_image = tpl_target_path
res_target = pe.Node(
ResampleImageBySpacing(out_spacing=(4, 4, 4), apply_smoothing=True),
name="res_target",
)
res_tmpl = pe.Node(RegridToZooms(in_file=tpl_target_path, zooms=(4, 4, 4), smooth=True),
name="res_tmpl")
res_target = pe.Node(RegridToZooms(zooms=(4, 4, 4), smooth=True), name="res_target")

lap_tmpl = pe.Node(ImageMath(operation="Laplacian", op2="1.5 1"), name="lap_tmpl")
lap_tmpl = pe.Node(ImageMath(operation="Laplacian", op2="1.5 1", copy_header=True),
name="lap_tmpl")
lap_tmpl.inputs.op1 = tpl_target_path
lap_target = pe.Node(
ImageMath(operation="Laplacian", op2="1.5 1"), name="lap_target"
ImageMath(operation="Laplacian", op2="1.5 1", copy_header=True), name="lap_target"
)
mrg_tmpl = pe.Node(niu.Merge(2), name="mrg_tmpl")
mrg_tmpl.inputs.in1 = tpl_target_path
Expand All @@ -263,7 +253,6 @@ def init_brain_extraction_wf(
)

# Tolerate missing ANTs at construction time
_ants_version = Registration().version
if _ants_version and parseversion(_ants_version) >= Version("2.3.0"):
init_aff.inputs.search_grid = (40, (0, 40, 40))

Expand All @@ -287,26 +276,20 @@ def init_brain_extraction_wf(
fixed_mask_trait += "s"

map_brainmask = pe.Node(
ApplyTransforms(interpolation="Gaussian", float=True),
ApplyTransforms(interpolation="Gaussian"),
name="map_brainmask",
mem_gb=1,
)
map_brainmask.inputs.input_image = str(tpl_mask_path)

thr_brainmask = pe.Node(
ThresholdImage(
dimension=3, th_low=0.5, th_high=1.0, inside_value=1, outside_value=0
dimension=3, th_low=0.5, th_high=1.0, inside_value=1, outside_value=0,
copy_header=True,
),
name="thr_brainmask",
)

# Morphological dilation, radius=2
dil_brainmask = pe.Node(ImageMath(operation="MD", op2="2"), name="dil_brainmask")
# Get largest connected component
get_brainmask = pe.Node(
ImageMath(operation="GetLargestComponent"), name="get_brainmask"
)

# Refine INU correction
inu_n4_final = pe.MapNode(
N4BiasFieldCorrection(
Expand Down Expand Up @@ -340,47 +323,37 @@ def init_brain_extraction_wf(
# fmt: off
wf.connect([
(inputnode, trunc, [("in_files", "op1")]),
(inputnode, copy_xform, [(("in_files", _pop), "hdr_file")]),
(inputnode, inu_n4_final, [("in_files", "input_image")]),
(inputnode, init_aff, [("in_mask", "fixed_image_mask")]),
(inputnode, norm, [("in_mask", fixed_mask_trait)]),
(inputnode, map_brainmask, [(("in_files", _pop), "reference_image")]),
(trunc, inu_n4, [("output_image", "input_image")]),
(inu_n4, res_target, [(("output_image", _pop), "input_image")]),
(res_tmpl, init_aff, [("output_image", "fixed_image")]),
(res_target, init_aff, [("output_image", "moving_image")]),
(inu_n4, res_target, [(("output_image", _pop), "in_file")]),
(res_tmpl, init_aff, [("out_file", "fixed_image")]),
(res_target, init_aff, [("out_file", "moving_image")]),
(init_aff, norm, [("output_transform", "initial_moving_transform")]),
(norm, map_brainmask, [
("reverse_transforms", "transforms"),
("reverse_invert_flags", "invert_transform_flags"),
]),
(map_brainmask, thr_brainmask, [("output_image", "input_image")]),
(thr_brainmask, dil_brainmask, [("output_image", "op1")]),
(dil_brainmask, get_brainmask, [("output_image", "op1")]),
(map_brainmask, inu_n4_final, [("output_image", "weight_image")]),
(inu_n4_final, apply_mask, [("output_image", "in_file")]),
(get_brainmask, apply_mask, [("output_image", "mask_file")]),
(get_brainmask, copy_xform, [("output_image", "out_mask")]),
(apply_mask, copy_xform, [("out_file", "out_file")]),
(inu_n4_final, copy_xform, [
("output_image", "bias_corrected"),
("bias_image", "bias_image"),
]),
(copy_xform, outputnode, [
("out_file", "out_file"),
("out_mask", "out_mask"),
("bias_corrected", "bias_corrected"),
("bias_image", "bias_image"),
]),
(thr_brainmask, apply_mask, [("output_image", "in_mask")]),
(thr_brainmask, outputnode, [("output_image", "out_mask")]),
(inu_n4_final, outputnode, [("output_image", "bias_corrected"),
("bias_image", "bias_image")]),
(apply_mask, outputnode, [("out_file", "out_file")]),
])
# fmt: on

if use_laplacian:
lap_tmpl = pe.Node(
ImageMath(operation="Laplacian", op2="1.5 1"), name="lap_tmpl"
ImageMath(operation="Laplacian", op2="1.5 1", copy_header=True), name="lap_tmpl"
)
lap_tmpl.inputs.op1 = tpl_target_path
lap_target = pe.Node(
ImageMath(operation="Laplacian", op2="1.5 1"), name="lap_target"
ImageMath(operation="Laplacian", op2="1.5 1", copy_header=True), name="lap_target"
)
mrg_tmpl = pe.Node(niu.Merge(2), name="mrg_tmpl")
mrg_tmpl.inputs.in1 = tpl_target_path
Expand Down Expand Up @@ -411,28 +384,23 @@ def init_brain_extraction_wf(
omp_nthreads=omp_nthreads,
mem_gb=mem_gb,
in_segmentation_model=atropos_model,
)
sel_wm = pe.Node(
niu.Select(index=atropos_model[-1] - 1),
name="sel_wm",
run_without_submitting=True,
bspline_fitting_distance=bspline_fitting_distance,
)

# fmt: off
wf.disconnect([
(get_brainmask, apply_mask, [("output_image", "mask_file")]),
(copy_xform, outputnode, [("out_mask", "out_mask")]),
(thr_brainmask, outputnode, [("output_image", "out_mask")]),
(inu_n4_final, outputnode, [("output_image", "bias_corrected"),
("bias_image", "bias_image")]),
(apply_mask, outputnode, [("out_file", "out_file")]),
])
wf.connect([
(inu_n4, atropos_wf, [("output_image", "inputnode.in_files")]),
(inu_n4_final, atropos_wf, [("output_image", "inputnode.in_files")]),
(thr_brainmask, atropos_wf, [("output_image", "inputnode.in_mask")]),
(get_brainmask, atropos_wf, [
("output_image", "inputnode.in_mask_dilated"),
]),
(atropos_wf, sel_wm, [("outputnode.out_tpms", "inlist")]),
(sel_wm, inu_n4_final, [("out", "weight_image")]),
(atropos_wf, apply_mask, [("outputnode.out_mask", "mask_file")]),
(atropos_wf, outputnode, [
("outputnode.out_file", "out_file"),
("outputnode.bias_corrected", "bias_corrected"),
("outputnode.out_mask", "bias_image"),
("outputnode.out_mask", "out_mask"),
("outputnode.out_segm", "out_segm"),
("outputnode.out_tpms", "out_tpms"),
Expand All @@ -449,6 +417,7 @@ def init_atropos_wf(
mem_gb=3.0,
padding=10,
in_segmentation_model=tuple(ATROPOS_MODELS["T1w"].values()),
bspline_fitting_distance=200,
):
"""
Create an ANTs' ATROPOS workflow for brain tissue segmentation.
Expand Down Expand Up @@ -514,19 +483,24 @@ def init_atropos_wf(
"""
wf = pe.Workflow(name)

out_fields = ["bias_corrected", "bias_image", "out_mask", "out_segm", "out_tpms"]

inputnode = pe.Node(
niu.IdentityInterface(fields=["in_files", "in_mask", "in_mask_dilated"]),
niu.IdentityInterface(fields=["in_files", "in_mask"]),
name="inputnode",
)
outputnode = pe.Node(
niu.IdentityInterface(fields=["out_mask", "out_segm", "out_tpms"]),
name="outputnode",
)
outputnode = pe.Node(niu.IdentityInterface(fields=["out_file"] + out_fields),
name="outputnode")

copy_xform = pe.Node(
CopyXForm(fields=["out_mask", "out_segm", "out_tpms"]),
name="copy_xform",
run_without_submitting=True,
copy_xform = pe.Node(CopyXForm(fields=out_fields),
name="copy_xform", run_without_submitting=True)

# Morphological dilation, radius=2
dil_brainmask = pe.Node(ImageMath(operation="MD", op2="2", copy_header=True),
name="dil_brainmask")
# Get largest connected component
get_brainmask = pe.Node(
ImageMath(operation="GetLargestComponent", copy_header=True), name="get_brainmask"
)

# Run atropos (core node)
Expand All @@ -549,10 +523,10 @@ def init_atropos_wf(

# massage outputs
pad_segm = pe.Node(
ImageMath(operation="PadImage", op2="%d" % padding), name="02_pad_segm"
ImageMath(operation="PadImage", op2=f"{padding}", copy_header=False), name="02_pad_segm"
)
pad_mask = pe.Node(
ImageMath(operation="PadImage", op2="%d" % padding), name="03_pad_mask"
ImageMath(operation="PadImage", op2=f"{padding}", copy_header=False), name="03_pad_mask"
)

# Split segmentation in binary masks
Expand Down Expand Up @@ -649,15 +623,57 @@ def init_atropos_wf(

msk_conform = pe.Node(niu.Function(function=_conform_mask), name="msk_conform")
merge_tpms = pe.Node(niu.Merge(in_segmentation_model[0]), name="merge_tpms")

sel_wm = pe.Node(
niu.Select(index=in_segmentation_model[-1] - 1),
name="sel_wm",
run_without_submitting=True,
)

copy_xform_wm = pe.Node(CopyXForm(fields=["depad_wm"]),
name="copy_xform_wm", run_without_submitting=True)

# Refine INU correction
inu_n4_final = pe.MapNode(
N4BiasFieldCorrection(
dimension=3,
save_bias=True,
copy_header=True,
n_iterations=[50] * 5,
convergence_threshold=1e-7,
shrink_factor=4,
bspline_fitting_distance=bspline_fitting_distance,
),
n_procs=omp_nthreads,
name="inu_n4_final",
iterfield=["input_image"],
)
if _ants_version and parseversion(_ants_version) >= Version("2.1.0"):
inu_n4_final.inputs.rescale_intensities = True
else:
warn(
"""\
Found ANTs version %s, which is too old. Please consider upgrading to 2.1.0 or \
greater so that the --rescale-intensities option is available with \
N4BiasFieldCorrection."""
% _ants_version,
DeprecationWarning,
)

# Apply mask
apply_mask = pe.MapNode(ApplyMask(), iterfield=["in_file"], name="apply_mask")

# fmt: off
wf.connect([
(inputnode, dil_brainmask, [("in_mask", "op1")]),
(inputnode, copy_xform, [(("in_files", _pop), "hdr_file")]),
(inputnode, copy_xform_wm, [(("in_files", _pop), "hdr_file")]),
(inputnode, pad_mask, [("in_mask", "op1")]),
(inputnode, atropos, [
("in_files", "intensity_images"),
("in_mask_dilated", "mask_image"),
]),
(inputnode, atropos, [("in_files", "intensity_images")]),
(inputnode, inu_n4_final, [("in_files", "input_image")]),
(inputnode, msk_conform, [(("in_files", _pop), "in_reference")]),
(dil_brainmask, get_brainmask, [("output_image", "op1")]),
(get_brainmask, atropos, [("output_image", "mask_image")]),
(atropos, pad_segm, [("classified_image", "op1")]),
(pad_segm, sel_labels, [("output_image", "in_segm")]),
(sel_labels, get_wm, [("out_wm", "op1")]),
Expand Down Expand Up @@ -694,7 +710,17 @@ def init_atropos_wf(
(msk_conform, copy_xform, [("out", "out_mask")]),
(depad_segm, copy_xform, [("output_image", "out_segm")]),
(merge_tpms, copy_xform, [("out", "out_tpms")]),
(merge_tpms, sel_wm, [("out", "inlist")]),
(sel_wm, copy_xform_wm, [("out", "depad_wm")]),
(copy_xform_wm, inu_n4_final, [("depad_wm", "weight_image")]),
(inu_n4_final, copy_xform, [("output_image", "bias_corrected"),
("bias_image", "bias_image")]),
(copy_xform, apply_mask, [("bias_corrected", "in_file"),
("out_mask", "in_mask")]),
(apply_mask, outputnode, [("out_file", "out_file")]),
(copy_xform, outputnode, [
("bias_corrected", "bias_corrected"),
("bias_image", "bias_image"),
("out_mask", "out_mask"),
("out_segm", "out_segm"),
("out_tpms", "out_tpms"),
Expand Down

0 comments on commit 0eb96fb

Please sign in to comment.