diff --git a/niworkflows/anat/ants.py b/niworkflows/anat/ants.py index 2e8d00f08e9..5f063a802fa 100644 --- a/niworkflows/anat/ants.py +++ b/niworkflows/anat/ants.py @@ -362,7 +362,21 @@ def init_brain_extraction_wf( map_wmmask = pe.Node( ApplyTransforms(interpolation="Gaussian"), name="map_wmmask", mem_gb=1, ) - map_wmmask.inputs.input_image = str(wm_tpm) + + # Add the brain stem if it is found. + bstem_tpm = ( + get_template(in_template, label="BS", suffix="probseg", **common_spec) or None + ) + if bstem_tpm: + full_wm = pe.Node(niu.Function(function=_imsum), name="full_wm") + full_wm.inputs.op1 = str(wm_tpm) + full_wm.inputs.op2 = str(bstem_tpm) + # fmt: off + wf.connect([ + (full_wm, map_wmmask, [("out", "input_image")]) + ]) + else: + map_wmmask.inputs.input_image = str(wm_tpm) # fmt: off wf.disconnect([ (map_brainmask, inu_n4_final, [("output_image", "weight_image")]), @@ -804,23 +818,22 @@ def _argmax(in_dice): run_without_submitting=True) overlap = pe.Node(FuzzyOverlap(), name="overlap", run_without_submitting=True) - apply_wm_prior = pe.Node( - MultiplyImages(dimension=3, output_product_image="regularized_wm.nii.gz",), - name="apply_wm_prior", - ) + apply_wm_prior = pe.Node(niu.Function(function=_improd), name="apply_wm_prior") + # fmt: off wf.disconnect([ (copy_xform_wm, inu_n4_final, [("wm_map", "weight_image")]), ]) wf.connect([ - (inputnode, apply_wm_prior, [("wm_prior", "second_input")]), + (inputnode, apply_wm_prior, [("in_mask", "in_mask"), + ("wm_prior", "op2")]), (inputnode, match_wm, [("wm_prior", "value")]), (atropos, match_wm, [("posteriors", "reference")]), (atropos, overlap, [("posteriors", "in_ref")]), (match_wm, overlap, [("out", "in_tst")]), (overlap, sel_wm, [(("class_fdi", _argmax), "index")]), - (copy_xform_wm, apply_wm_prior, [("wm_map", "first_input")]), - (apply_wm_prior, inu_n4_final, [("output_product_image", "weight_image")]), + (copy_xform_wm, apply_wm_prior, [("wm_map", "op1")]), + (apply_wm_prior, inu_n4_final, [("out", "weight_image")]), ]) # fmt: on return wf @@ -1052,3 +1065,39 @@ def _conform_mask(in_mask, in_reference): def _matchlen(value, reference): return [value] * len(reference) + + +def _imsum(op1, op2, out_file=None): + import nibabel as nb + im1 = nb.load(op1) + + data = im1.get_fdata() + nb.load(op2).get_fdata() + data /= data.max() + nii = nb.Nifti1Image(data, im1.affine, im1.header) + + if out_file is None: + from pathlib import Path + out_file = str((Path() / "summap.nii.gz").absolute()) + + nii.to_filename(out_file) + return out_file + + +def _improd(op1, op2, in_mask, out_file=None): + import nibabel as nb + im1 = nb.load(op1) + + data = im1.get_fdata() * nb.load(op2).get_fdata() + mskdata = nb.load(in_mask).get_fdata() > 0 + data[~mskdata] = 0 + data[data < 0] = 0 + data /= data.max() + data = 0.5 * (data + mskdata) + nii = nb.Nifti1Image(data, im1.affine, im1.header) + + if out_file is None: + from pathlib import Path + out_file = str((Path() / "prodmap.nii.gz").absolute()) + + nii.to_filename(out_file) + return out_file