Skip to content

Commit

Permalink
enh: add brainstem as part of the prior & brainmask for final N4
Browse files Browse the repository at this point in the history
Addresses issues like #567 (comment)
  • Loading branch information
oesteban committed Sep 22, 2020
1 parent 18639b7 commit a89688e
Showing 1 changed file with 58 additions and 11 deletions.
69 changes: 58 additions & 11 deletions niworkflows/anat/ants.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,22 @@ def init_brain_extraction_wf(
map_wmmask = pe.Node(
ApplyTransforms(interpolation="Gaussian"), name="map_wmmask", mem_gb=1,
)
map_wmmask.inputs.input_image = str(wm_tpm)

# Add the brain stem if it is found.
bstem_tpm = (
get_template(in_template, label="BS", suffix="probseg", **common_spec) or None
)
if bstem_tpm:
full_wm = pe.Node(niu.Function(function=_imsum), name="full_wm")
full_wm.inputs.op1 = str(wm_tpm)
full_wm.inputs.op2 = str(bstem_tpm)
# fmt: off
wf.connect([
(full_wm, map_wmmask, [("out", "input_image")])
])
# fmt: on
else:
map_wmmask.inputs.input_image = str(wm_tpm)
# fmt: off
wf.disconnect([
(map_brainmask, inu_n4_final, [("output_image", "weight_image")]),
Expand Down Expand Up @@ -783,26 +798,22 @@ def _argmax(in_dice):
run_without_submitting=True)
overlap = pe.Node(FuzzyOverlap(), name="overlap", run_without_submitting=True)

apply_wm_prior = pe.Node(
MultiplyImages(
dimension=3,
output_product_image="regularized_wm.nii.gz",
),
name="apply_wm_prior",
)
apply_wm_prior = pe.Node(niu.Function(function=_improd), name="apply_wm_prior")

# fmt: off
wf.disconnect([
(copy_xform_wm, inu_n4_final, [("wm_map", "weight_image")]),
])
wf.connect([
(inputnode, apply_wm_prior, [("wm_prior", "second_input")]),
(inputnode, apply_wm_prior, [("in_mask", "in_mask"),
("wm_prior", "op2")]),
(inputnode, match_wm, [("wm_prior", "value")]),
(atropos, match_wm, [("posteriors", "reference")]),
(atropos, overlap, [("posteriors", "in_ref")]),
(match_wm, overlap, [("out", "in_tst")]),
(overlap, sel_wm, [(("class_fdi", _argmax), "index")]),
(copy_xform_wm, apply_wm_prior, [("wm_map", "first_input")]),
(apply_wm_prior, inu_n4_final, [("output_product_image", "weight_image")]),
(copy_xform_wm, apply_wm_prior, [("wm_map", "op1")]),
(apply_wm_prior, inu_n4_final, [("out", "weight_image")]),
])
# fmt: on
return wf
Expand Down Expand Up @@ -1034,3 +1045,39 @@ def _conform_mask(in_mask, in_reference):

def _matchlen(value, reference):
return [value] * len(reference)


def _imsum(op1, op2, out_file=None):
import nibabel as nb
im1 = nb.load(op1)

data = im1.get_fdata() + nb.load(op2).get_fdata()
data /= data.max()
nii = nb.Nifti1Image(data, im1.affine, im1.header)

if out_file is None:
from pathlib import Path
out_file = str((Path() / "summap.nii.gz").absolute())

nii.to_filename(out_file)
return out_file


def _improd(op1, op2, in_mask, out_file=None):
import nibabel as nb
im1 = nb.load(op1)

data = im1.get_fdata() * nb.load(op2).get_fdata()
mskdata = nb.load(in_mask).get_fdata() > 0
data[~mskdata] = 0
data[data < 0] = 0
data /= data.max()
data = 0.5 * (data + mskdata)
nii = nb.Nifti1Image(data, im1.affine, im1.header)

if out_file is None:
from pathlib import Path
out_file = str((Path() / "prodmap.nii.gz").absolute())

nii.to_filename(out_file)
return out_file

0 comments on commit a89688e

Please sign in to comment.