From 934f4a7512effcdcef6365226ec1073d786e92e1 Mon Sep 17 00:00:00 2001 From: 36000 Date: Mon, 29 Apr 2024 12:58:15 -0700 Subject: [PATCH 01/11] [REF] Refactor segmentation code into many files and pimms system --- AFQ/api/bundle_dict.py | 1 + AFQ/api/group.py | 5 +- AFQ/api/participant.py | 3 +- AFQ/api/utils.py | 8 +- AFQ/bundle_rec/__init__.py | 0 AFQ/bundle_rec/cleaning.py | 186 +++ AFQ/bundle_rec/criteria.py | 386 +++++ AFQ/bundle_rec/curvature.py | 52 + AFQ/bundle_rec/preprocess.py | 73 + AFQ/bundle_rec/recognize.py | 290 ++++ AFQ/bundle_rec/roi.py | 94 ++ AFQ/bundle_rec/utils.py | 222 +++ AFQ/data/fetch.py | 10 +- AFQ/segmentation.py | 1444 ------------------- AFQ/tasks/data.py | 9 +- AFQ/tasks/segmentation.py | 16 +- AFQ/tasks/utils.py | 5 +- AFQ/tests/test_api.py | 5 +- AFQ/tests/test_segmentation.py | 161 +-- AFQ/utils/bin.py | 14 +- docs/source/_progressbars.py | 2 - docs/source/reference/bundledict.rst | 22 +- examples/howto_examples/plot_recobundles.py | 5 +- 23 files changed, 1423 insertions(+), 1590 deletions(-) create mode 100644 AFQ/bundle_rec/__init__.py create mode 100644 AFQ/bundle_rec/cleaning.py create mode 100644 AFQ/bundle_rec/criteria.py create mode 100644 AFQ/bundle_rec/curvature.py create mode 100644 AFQ/bundle_rec/preprocess.py create mode 100644 AFQ/bundle_rec/recognize.py create mode 100644 AFQ/bundle_rec/roi.py create mode 100644 AFQ/bundle_rec/utils.py delete mode 100644 AFQ/segmentation.py diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index 3195c9b35..e4b8b3d09 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -684,6 +684,7 @@ def __setitem__(self, key, value): "in the BundleDict.")) +# TODO: refactor the recobundles parts of this class BundleDict(MutableMapping): """ Create a bundle dictionary, needed for the segmentation. diff --git a/AFQ/api/group.py b/AFQ/api/group.py index 892ddefc1..76a297a60 100644 --- a/AFQ/api/group.py +++ b/AFQ/api/group.py @@ -590,11 +590,8 @@ def export_all(self, viz=True, afqbrowser=True, xforms=True, Default: True """ start_time = time() - seg_params = self.export("segmentation_params", collapse=False)[ - self.valid_sub_list[0]][self.valid_ses_list[0]] - seg_algo = seg_params.get("seg_algo", "AFQ") - export_all_helper(self, seg_algo, xforms, indiv, viz) + export_all_helper(self, xforms, indiv, viz) self.combine_profiles() if afqbrowser: diff --git a/AFQ/api/participant.py b/AFQ/api/participant.py index da543a3a3..9b87bda83 100644 --- a/AFQ/api/participant.py +++ b/AFQ/api/participant.py @@ -198,8 +198,7 @@ def export_all(self, viz=True, xforms=True, Default: True """ start_time = time() - seg_algo = self.export("segmentation_params").get("seg_algo", "AFQ") - export_all_helper(self, seg_algo, xforms, indiv, viz) + export_all_helper(self, xforms, indiv, viz) self.logger.info( f"Time taken for export all: {time() - start_time}") diff --git a/AFQ/api/utils.py b/AFQ/api/utils.py index da789fa37..cf5fc73d7 100644 --- a/AFQ/api/utils.py +++ b/AFQ/api/utils.py @@ -109,7 +109,7 @@ def check_attribute(attr_name): f"{attr_name} not found for export. 
{valid_exports_string}")


-def export_all_helper(api_afq_object, seg_algo, xforms, indiv, viz):
+def export_all_helper(api_afq_object, xforms, indiv, viz):
     if xforms:
         try:
             api_afq_object.export("b0_warped")
@@ -122,8 +122,7 @@ def export_all_helper(api_afq_object, seg_algo, xforms, indiv, viz):
     if indiv:
         api_afq_object.export("indiv_bundles")
-        if seg_algo == "AFQ":
-            api_afq_object.export("rois")
+        api_afq_object.export("rois")
     api_afq_object.export("sl_counts")
     api_afq_object.export("median_bundle_lengths")
     api_afq_object.export("profiles")
@@ -138,5 +137,4 @@ def export_all_helper(api_afq_object, seg_algo, xforms, indiv, viz):
         else:
             api_afq_object.logger.warning(plot_err_message)
     api_afq_object.export("all_bundles_figure")
-    if seg_algo == "AFQ":
-        api_afq_object.export("indiv_bundles_figures")
+    api_afq_object.export("indiv_bundles_figures")
diff --git a/AFQ/bundle_rec/__init__.py b/AFQ/bundle_rec/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/AFQ/bundle_rec/cleaning.py b/AFQ/bundle_rec/cleaning.py
new file mode 100644
index 000000000..cbc02948f
--- /dev/null
+++ b/AFQ/bundle_rec/cleaning.py
@@ -0,0 +1,186 @@
+import numpy as np
+import logging
+
+from scipy.stats import zscore
+
+import dipy.tracking.streamline as dts
+from dipy.stats.analysis import gaussian_weights
+from dipy.io.stateful_tractogram import StatefulTractogram, Space
+
+import AFQ.bundle_rec.utils as abu
+
+
+logger = logging.getLogger('AFQ')
+
+
+def clean_by_orientation(streamlines, primary_axis, tol=None):
+    """
+    Filter streamlines by the cardinal orientation along which
+    they mainly run
+
+    Parameters
+    ----------
+    streamlines : sequence of N by 3 arrays
+        Where N is number of nodes in the array, the collection of
+        streamlines to filter down to.
+    primary_axis : int
+        The cardinal axis (0, 1, or 2) along which streamlines in this
+        bundle are expected to run.
+    tol : float, optional
+        If not None, the minimum percentage of each streamline's total
+        displacement that must be along the primary axis.
+        Default: None.
+
+    Returns
+    -------
+    cleaned_idx, boolean array of streamlines that passed cleaning:
+        the logical AND of passing cleaning along the bundle and
+        passing cleaning based on the difference between the
+        endpoints of the bundle
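+
+    Examples
+    --------
+    A minimal sketch with made-up coordinates (not from the pyAFQ test
+    suite): the first streamline runs mainly along axis 0, the second
+    mainly along axis 1, so only the first passes with primary_axis=0:
+
+    >>> import numpy as np
+    >>> sls = [np.array([[0., 0., 0.], [5., 1., 1.]]),
+    ...        np.array([[0., 0., 0.], [1., 5., 1.]])]
+    >>> clean_by_orientation(sls, 0)
+    array([ True, False])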
+    """
+    axis_diff = np.zeros((len(streamlines), 3))
+    endpoint_diff = np.zeros((len(streamlines), 3))
+    for ii, sl in enumerate(streamlines):
+        # endpoint diff is between first and last
+        endpoint_diff[ii, :] = np.abs(sl[0, :] - sl[-1, :])
+        # axis diff is difference between the nodes, along
+        axis_diff[ii, :] = np.sum(np.abs(np.diff(sl, axis=0)), axis=0)
+
+    orientation_along = np.argmax(axis_diff, axis=1)
+    along_accepted_idx = orientation_along == primary_axis
+    if tol is not None:
+        percentage_primary = 100 * axis_diff[:, primary_axis] / np.sum(
+            axis_diff, axis=1)
+        logger.debug((
+            "Maximum primary percentage found: "
+            f"{np.max(percentage_primary)}"))
+        along_accepted_idx = np.logical_and(
+            along_accepted_idx, percentage_primary > tol)
+
+    orientation_end = np.argmax(endpoint_diff, axis=1)
+    end_accepted_idx = orientation_end == primary_axis
+
+    cleaned_idx = np.logical_and(
+        along_accepted_idx,
+        end_accepted_idx)
+
+    return cleaned_idx
+
+
+def clean_bundle(tg, n_points=100, clean_rounds=5, distance_threshold=3,
+                 length_threshold=4, min_sl=20, stat='mean',
+                 return_idx=False):
+    """
+    Clean a segmented fiber group based on the Mahalanobis distance of
+    each streamline
+
+    Parameters
+    ----------
+    tg : StatefulTractogram class instance or ArraySequence
+        A whole-brain tractogram to be segmented.
+    n_points : int, optional
+        Number of points to resample streamlines to.
+        Default: 100
+    clean_rounds : int, optional.
+        Number of rounds of cleaning based on the Mahalanobis distance from
+        the mean of extracted bundles. Default: 5
+    distance_threshold : float, optional.
+        Threshold of cleaning based on the Mahalanobis distance (the units
+        are standard deviations). Default: 3.
+    length_threshold: float, optional
+        Threshold for cleaning based on length (in standard deviations).
+        Length of any streamline should not be *more* than this number of
+        stdevs from the mean length.
+    min_sl : int, optional.
+        Number of streamlines in a bundle under which we will
+        not bother with cleaning outliers. Default: 20.
+    stat : callable or str, optional.
+        The statistic of each node relative to which the Mahalanobis is
+        calculated. Default: `np.mean` (but can also use median, etc.)
+    return_idx : bool
+        Whether to return indices in the original streamlines.
+        Default: False.
+
+    Returns
+    -------
+    A StatefulTractogram class instance containing only the streamlines
+    that have a Mahalanobis distance smaller than `distance_threshold` from
+    the mean of each one of the nodes.
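+
+    Examples
+    --------
+    A minimal sketch with made-up streamlines: bundles smaller than
+    `min_sl` are returned unchanged, so this tiny "bundle" passes
+    straight through:
+
+    >>> import numpy as np
+    >>> sls = [np.array([[0., 0., 0.], [1., 1., 1.], [2., 2., 2.]])] * 5
+    >>> len(clean_bundle(sls))
+    5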
"streamlines"): + out = StatefulTractogram(tg.streamlines[idx], tg, Space.VOX) + else: + out = streamlines[idx] + if return_idx: + return out, idx + else: + return out diff --git a/AFQ/bundle_rec/criteria.py b/AFQ/bundle_rec/criteria.py new file mode 100644 index 000000000..15b544ff8 --- /dev/null +++ b/AFQ/bundle_rec/criteria.py @@ -0,0 +1,386 @@ +import pimms +import numpy as np +import logging +from time import time + +import dipy.tracking.streamline as dts +from dipy.utils.parallel import paramap +from dipy.segment.clustering import QuickBundles +from dipy.segment.metricspeed import AveragePointwiseEuclideanMetric +from dipy.segment.featurespeed import ResampleFeature +from dipy.io.streamline import load_tractogram +from dipy.segment.bundles import RecoBundles +from dipy.io.stateful_tractogram import StatefulTractogram, Space + +import AFQ.bundle_rec.utils as abu +import AFQ.bundle_rec.cleaning as abc +import AFQ.bundle_rec.curvature as abv +import AFQ.bundle_rec.roi as abr + +bundle_criterion_order = [ + "prob_map", "cross_midline", "start", "end", + "length", "primary_axis", "include", "exclude", + "recobundles", "qb_thresh"] + + +logger = logging.getLogger('AFQ') + + +@pimms.calc("prob_map") +def prob_map(b_sls, bundle_def, preproc_imap, prob_threshold): + b_sls.initiate_selection("Prob. Map") + # using entire fgarray here only because it is the first step + fiber_probabilities = dts.values_from_volume( + bundle_def["prob_map"].get_fdata(), + preproc_imap["fgarray"], np.eye(4)) + fiber_probabilities = np.mean(fiber_probabilities, -1) + b_sls.select( + fiber_probabilities > prob_threshold, + "Prob. Map") + return 1 # You have to return something + + +@pimms.calc("cross_midline") +def cross_midline(b_sls, bundle_def, preproc_imap): + b_sls.initiate_selection("Cross Mid.") + accepted = preproc_imap["crosses"][b_sls.selected_fiber_idxs] + if not bundle_def["cross_midline"]: + accepted = np.invert(accepted) + b_sls.select(accepted, "Cross Mid.") + return 1 + + +@pimms.calc("start") +def start(b_sls, bundle_def, preproc_imap): + accept_idx = b_sls.initiate_selection("Startpoint") + abr.clean_by_endpoints( + b_sls.get_selected_sls(), + bundle_def["start"], + 0, + tol=preproc_imap["dist_to_atlas"], + flip_sls=b_sls.sls_flipped, + accepted_idxs=accept_idx) + if not b_sls.oriented_yet: + accepted_idx_flipped = abr.clean_by_endpoints( + b_sls.get_selected_sls(), + bundle_def["start"], + -1, + tol=preproc_imap["dist_to_atlas"]) + b_sls.reorient(accepted_idx_flipped) + accept_idx = np.logical_xor( + accepted_idx_flipped, accept_idx) + b_sls.select(accept_idx, "Startpoint") + return 1 + + +@pimms.calc("end") +def end(b_sls, bundle_def, preproc_imap): + accept_idx = b_sls.initiate_selection("endpoint") + abr.clean_by_endpoints( + b_sls.get_selected_sls(), + bundle_def["end"], + -1, + tol=preproc_imap["dist_to_atlas"], + flip_sls=b_sls.sls_flipped, + accepted_idxs=accept_idx) + if not b_sls.oriented_yet: + accepted_idx_flipped = abr.clean_by_endpoints( + b_sls.get_selected_sls(), + bundle_def["end"], + 0, + tol=preproc_imap["dist_to_atlas"]) + b_sls.reorient(accepted_idx_flipped) + accept_idx = np.logical_xor( + accepted_idx_flipped, accept_idx) + b_sls.select(accept_idx, "endpoint") + return 1 + + +@pimms.calc("length") +def length(b_sls, bundle_def, preproc_imap): + accept_idx = b_sls.initiate_selection("length") + min_len = bundle_def["length"].get( + "min_len", 0) / preproc_imap["vox_dim"] + max_len = bundle_def["length"].get( + "max_len", np.inf) / preproc_imap["vox_dim"] + for idx, sl in 
enumerate(b_sls.get_selected_sls()): + sl_len = np.sum( + np.linalg.norm(np.diff(sl, axis=0), axis=1)) + if sl_len >= min_len and sl_len <= max_len: + accept_idx[idx] = 1 + b_sls.select(accept_idx, "length") + return 1 + + +@pimms.calc("primary_axis") +def primary_axis(b_sls, bundle_def): + b_sls.initiate_selection("orientation") + accept_idx = abc.clean_by_orientation( + b_sls.get_selected_sls(), + bundle_def["primary_axis"], + bundle_def.get( + "primary_axis_percentage", None)) + b_sls.select(accept_idx, "orientation") + return 1 + + +@pimms.calc("include") +def include(b_sls, bundle_def, preproc_imap, max_includes, + parallel_segmentation): + accept_idx = b_sls.initiate_selection("include") + flip_using_include = len(bundle_def["include"]) > 1\ + and not b_sls.oriented_yet + + if f'inc_addtol' in bundle_def: + include_roi_tols = [] + for inc_tol in bundle_def["inc_addtol"]: + include_roi_tols.append(( + inc_tol / preproc_imap["vox_dim"] + preproc_imap["tol"])**2) + else: + include_roi_tols = [preproc_imap["tol"]**2] * len(bundle_def["include"]) + + include_rois = [] + for include_roi in bundle_def["include"]: + include_rois.append(np.array( + np.where(include_roi.get_fdata())).T) + + # with parallel segmentation, the first for loop will + # only collect streamlines and does not need tqdm + if parallel_segmentation["engine"] != "serial": + inc_results = paramap( + abr.check_sl_with_inclusion, b_sls.get_selected_sls(), + func_args=[ + include_rois, include_roi_tols], + **parallel_segmentation) + + else: + inc_results = abr.check_sls_with_inclusion( + b_sls.get_selected_sls(), + include_rois, + include_roi_tols) + + roi_dists = -np.ones( + (len(b_sls), max_includes), + dtype=np.int32) + if flip_using_include: + to_flip = np.ones_like(accept_idx, dtype=np.bool8) + for sl_idx, inc_result in enumerate(inc_results): + sl_accepted, sl_dist = inc_result + + if sl_accepted: + if len(sl_dist) > 1: + roi_dists[sl_idx, :len(sl_dist)] = [ + np.argmin(dist, 0)[0] + for dist in sl_dist] + first_roi_idx = roi_dists[sl_idx, 0] + last_roi_idx = roi_dists[ + sl_idx, len(sl_dist) - 1] + # Only accept SLs that, when cut, are meaningful + if (len(sl_dist) < 2) or abs( + first_roi_idx - last_roi_idx) > 1: + # Flip sl if it is close to second ROI + # before its close to the first ROI + if flip_using_include: + to_flip[sl_idx] =\ + first_roi_idx > last_roi_idx + if to_flip[sl_idx]: + roi_dists[sl_idx, :len(sl_dist)] =\ + np.flip(roi_dists[ + sl_idx, :len(sl_dist)]) + accept_idx[sl_idx] = 1 + else: + accept_idx[sl_idx] = 1 + # see https://github.com/joblib/joblib/issues/945 + if ( + (parallel_segmentation.get( + "engine", "joblib") != "serial") + and (parallel_segmentation.get( + "backend", "loky") == "loky")): + from joblib.externals.loky import get_reusable_executor + get_reusable_executor().shutdown(wait=True) + b_sls.roi_dists = roi_dists + if flip_using_include: + b_sls.reorient(to_flip) + b_sls.select(accept_idx, "include") + return 1 + +# Filters streamlines by how well they match +# a curve in orientation and shape but not scale + + +@pimms.calc("curvature") +def curvature(b_sls, bundle_def, mapping, img, save_intermediates): + accept_idx = b_sls.initiate_selection("curvature") + if "sft" in bundle_def["curvature"]: + ref_sl = bundle_def["curvature"]["sft"] + else: + ref_sl = load_tractogram( + bundle_def["curvature"]["path"], "same", + bbox_valid_check=False) + moved_ref_sl = abu.move_streamlines( + ref_sl, "subject", mapping, img, + save_intermediates=save_intermediates) + moved_ref_sl.to_vox() + 
moved_ref_sl = moved_ref_sl.streamlines[0] + moved_ref_curve = abv.sl_curve( + moved_ref_sl, + len(moved_ref_sl)) + ref_curve_threshold = np.radians(bundle_def["curvature"].get( + "thresh", 10)) + cut = bundle_def["curvature"].get("cut", True) + for idx, sl in enumerate(b_sls.get_selected_sls( + cut=cut, flip=True)): + if len(sl) > 1: + this_sl_curve = abv.sl_curve(sl, len(moved_ref_sl)) + dist = abv.sl_curve_dist(this_sl_curve, moved_ref_curve) + if dist <= ref_curve_threshold: + accept_idx[idx] = 1 + b_sls.select(accept_idx, "curvature", cut=cut) + return 1 + + +@pimms.calc("exclude") +def exclude(b_sls, bundle_def, preproc_imap): + accept_idx = b_sls.initiate_selection("exclude") + if f'exc_addtol' in bundle_def: + exclude_roi_tols = [] + for exc_tol in bundle_def["exc_addtol"]: + exclude_roi_tols.append(( + exc_tol / preproc_imap["vox_dim"] + preproc_imap["tol"])**2) + else: + exclude_roi_tols = [ + preproc_imap["tol"]**2] * len(bundle_def["exclude"]) + exclude_rois = [] + for exclude_roi in bundle_def["exclude"]: + exclude_rois.append(np.array( + np.where(exclude_roi.get_fdata())).T) + for sl_idx, sl in enumerate(b_sls.get_selected_sls()): + if abr.check_sl_with_exclusion( + sl, exclude_rois, exclude_roi_tols): + accept_idx[sl_idx] = 1 + b_sls.select(accept_idx, "exclude") + return 1 + + +@pimms.calc("recobundles") +def recobundles(b_sls, mapping, bundle_def, reg_template, img, refine_reco, + save_intermediates, rng, rb_recognize_params): + b_sls.initiate_selection("Recobundles") + moved_sl = abu.move_streamlines( + StatefulTractogram(b_sls.get_selected_sls(), img, Space.VOX), + "template", mapping, reg_template, + save_intermediates=save_intermediates).streamlines + rb = RecoBundles(moved_sl, verbose=True, rng=rng) + _, rec_labels = rb.recognize( + bundle_def['recobundles']['sl'], + **rb_recognize_params) + if refine_reco: + _, rec_labels = rb.refine( + bundle_def['recobundles']['sl'], moved_sl[rec_labels], + **rb_recognize_params) + if not b_sls.oriented_yet: + standard_sl = next(iter(bundle_def['recobundles']['centroid'])) + oriented_idx = abu.orient_by_streamline( + moved_sl[rec_labels], + standard_sl) + b_sls.reorient(rec_labels[oriented_idx]) + b_sls.select(rec_labels, "Recobundles") + return 1 + + +@pimms.calc("qb_thresh") +def qb_thresh(b_sls, bundle_def, preproc_imap, clip_edges): + b_sls.initiate_selection("qb_thresh") + cut = clip_edges or ("bundlesection" in bundle_def) + qbx = QuickBundles( + bundle_def["qb_thresh"] / preproc_imap["vox_dim"], + AveragePointwiseEuclideanMetric( + ResampleFeature(nb_points=12))) + clusters = qbx.cluster(b_sls.get_selected_sls( + cut=cut, flip=True)) + cleaned_idx = clusters[np.argmax( + clusters.clusters_sizes())].indices + b_sls.select(cleaned_idx, "qb_thresh", cut=cut) + return 1 + + +@pimms.calc("mahalanobis") +def mahalanobis(b_sls, bundle_def, clip_edges, cleaning_params): + b_sls.initiate_selection("Mahalanobis") + clean_params = bundle_def.get("mahal", {}) + clean_params = { + **cleaning_params, + **clean_params} + clean_params["return_idx"] = True + cut = clip_edges or ("bundlesection" in bundle_def) + _, cleaned_idx = abc.clean_bundle( + b_sls.get_selected_sls(cut=cut, flip=True), + **clean_params) + b_sls.select(cleaned_idx, "Mahalanobis", cut=cut) + return 1 + + +def run_bundle_rec_plan( + bundle_dict, tg, mapping, img, reg_template, preproc_imap, + bundle_name, bundle_idx, bundle_to_flip, bundle_roi_dists, + bundle_decisions, + **segmentation_params): + # Warp ROIs + logger.info(f"Preparing ROIs for {bundle_name}") + start_time 
= time()
+    bundle_def = dict(bundle_dict.get_b_info(bundle_name))
+    bundle_def.update(bundle_dict.transform_rois(
+        bundle_name,
+        mapping,
+        img.affine))
+    logger.info(f"Time to prep ROIs: {time()-start_time}s")
+
+    b_sls = abu.SlsBeingRecognized(
+        tg.streamlines, logger,
+        segmentation_params["save_intermediates"],
+        bundle_name,
+        img, len(bundle_def.get("include", [])))
+
+    bundle_plan = {}
+    for criterion in bundle_criterion_order:
+        bundle_plan[criterion] = globals()[criterion]
+    bundle_plan["mahalanobis"] = mahalanobis
+    bundle_plan = pimms.plan(bundle_plan)
+
+    inputs = {}
+    inputs["b_sls"] = b_sls
+    inputs["preproc_imap"] = preproc_imap
+    inputs["bundle_def"] = bundle_def
+    inputs["max_includes"] = bundle_dict.max_includes
+    inputs["mapping"] = mapping
+    inputs["img"] = img
+    inputs["reg_template"] = reg_template
+    for key, value in segmentation_params.items():
+        inputs[key] = value
+
+    bundle_imap = bundle_plan(inputs)
+    for criterion in bundle_criterion_order:
+        if b_sls and criterion in bundle_def:
+            bundle_imap[criterion]
+    bundle_imap["mahalanobis"]
+
+    if b_sls and not b_sls.oriented_yet:
+        raise ValueError(
+            "pyAFQ was unable to consistently orient streamlines "
+            f"in bundle {bundle_name} using the provided ROIs. "
+            "This can be fixed by including at least 2 "
+            "waypoint ROIs, or by using "
+            "endpoint ROIs.")
+
+    if b_sls:
+        bundle_to_flip[
+            b_sls.selected_fiber_idxs,
+            bundle_idx] = b_sls.sls_flipped.copy()
+        bundle_decisions[
+            b_sls.selected_fiber_idxs,
+            bundle_idx] = 1
+        if hasattr(b_sls, "roi_dists"):
+            bundle_roi_dists[
+                b_sls.selected_fiber_idxs,
+                bundle_idx
+            ] = b_sls.roi_dists.copy()
diff --git a/AFQ/bundle_rec/curvature.py b/AFQ/bundle_rec/curvature.py
new file mode 100644
index 000000000..390b6c019
--- /dev/null
+++ b/AFQ/bundle_rec/curvature.py
@@ -0,0 +1,52 @@
+import numpy as np
+
+import dipy.tracking.streamlinespeed as dps
+
+
+def sl_curve(sl, n_points):
+    """
+    Calculate the direction of the displacement between
+    each point along a streamline
+
+    Parameters
+    ----------
+    sl : 2d array-like
+        Streamline to calculate displacements for.
+    n_points : int
+        Number of points to resample the streamline to
+
+    Returns
+    -------
+    2d array of shape (n_points - 1, 3) with displacements
+    between each point in the resampled sl, normalized to 1.
+    """
+    # Resample to a standardized number of points
+    resampled_sl = dps.set_number_of_points(
+        sl,
+        n_points)
+
+    # displacement at each point
+    resampled_sl_diff = np.diff(resampled_sl, axis=0)
+
+    # normalize this displacement
+    resampled_sl_diff = resampled_sl_diff / np.linalg.norm(
+        resampled_sl_diff, axis=1)[:, None]
+
+    return resampled_sl_diff
+
+
+def sl_curve_dist(curve1, curve2):
+    """
+    Calculate the mean angle using the directions of displacement
+    between two streamlines
+
+    Parameters
+    ----------
+    curve1, curve2 : 2d array-like
+        Two curves calculated from sl_curve.
+        Both must be resampled to the same number of points.
+ + Returns + ------- + The mean angle between each curve across all steps, in radians + """ + return np.mean(np.arccos(np.sum(curve1 * curve2, axis=1))) diff --git a/AFQ/bundle_rec/preprocess.py b/AFQ/bundle_rec/preprocess.py new file mode 100644 index 000000000..73628eddd --- /dev/null +++ b/AFQ/bundle_rec/preprocess.py @@ -0,0 +1,73 @@ +import numpy as np +import pimms +from time import time +import logging + +import dipy.tracking.streamline as dts + +import AFQ.bundle_rec.utils as abu + + +logger = logging.getLogger('AFQ') + + +@pimms.calc("tol", "dist_to_atlas", "vox_dim") +def incexc_tol(img, dist_to_waypoint, input_dist_to_atlas): + # We need to calculate the size of a voxel, so we can transform + # from mm to voxel units: + R = img.affine[0:3, 0:3] + vox_dim = np.mean(np.diag(np.linalg.cholesky(R.T.dot(R)))) + + # Tolerance is set to the square of the distance to the corner + # because we are using the squared Euclidean distance in calls to + # `cdist` to make those calls faster. + if dist_to_waypoint is None: + tol = dts.dist_to_corner(img.affine) + else: + tol = dist_to_waypoint / vox_dim + dist_to_atlas = int(input_dist_to_atlas / vox_dim) + return tol, dist_to_atlas, vox_dim + + +@pimms.calc("fgarray") +def fgarray(tg): + """ + Streamlines resampled to 20 points. + """ + logger.info("Resampling Streamlines...") + start_time = time() + fg_array = np.array(abu.resample_tg(tg, 20)) + logger.info(( + "Streamlines Resampled " + f"(time: {time()-start_time}s)")) + return fg_array + + +@pimms.calc("crosses") +def crosses(fgarray, img): + """ + Classify the streamlines by whether they cross the midline. + Creates a crosses attribute which is an array of booleans. Each boolean + corresponds to a streamline, and is whether or not that streamline + crosses the midline. + """ + # What is the x,y,z coordinate of 0,0,0 in the template space? 
+    zero_coord = np.dot(np.linalg.inv(img.affine),
+                        np.array([0, 0, 0, 1]))
+
+    return np.logical_and(
+        np.any(fgarray[:, :, 0] > zero_coord[0], axis=1),
+        np.any(fgarray[:, :, 0] < zero_coord[0], axis=1))
+
+
+# Things that can be calculated for multiple bundles at once
+# (i.e., for a whole tractogram) go here
+def get_preproc_plan(img, tg, dist_to_waypoint, dist_to_atlas):
+    preproc_plan = pimms.Plan(
+        incexc_tol=incexc_tol,
+        fgarray=fgarray,
+        crosses=crosses)
+    return preproc_plan(
+        img=img, tg=tg,
+        dist_to_waypoint=dist_to_waypoint,
+        input_dist_to_atlas=dist_to_atlas)
diff --git a/AFQ/bundle_rec/recognize.py b/AFQ/bundle_rec/recognize.py
new file mode 100644
index 000000000..9e1771343
--- /dev/null
+++ b/AFQ/bundle_rec/recognize.py
@@ -0,0 +1,290 @@
+import numpy as np
+import os.path as op
+import os
+import logging
+
+import dipy.tracking.streamlinespeed as dps
+from dipy.io.stateful_tractogram import StatefulTractogram, Space
+
+import AFQ.bundle_rec.utils as abu
+from AFQ.api.bundle_dict import BundleDict
+from AFQ.bundle_rec.criteria import run_bundle_rec_plan
+from AFQ.bundle_rec.preprocess import get_preproc_plan
+
+
+logger = logging.getLogger('AFQ')
+
+
+def recognize(
+        tg,
+        img,
+        mapping,
+        bundle_dict,
+        reg_template,
+        nb_points=False,
+        nb_streamlines=False,
+        clip_edges=False,
+        parallel_segmentation={"engine": "serial"},
+        rb_recognize_params=dict(
+            model_clust_thr=1.25,
+            reduction_thr=25,
+            pruning_thr=12),
+        refine_reco=False,
+        prob_threshold=0,
+        dist_to_waypoint=None,
+        rng=None,
+        return_idx=False,
+        filter_by_endpoints=True,
+        dist_to_atlas=4,
+        save_intermediates=None,
+        cleaning_params={}):
+    """
+    Segment streamlines into bundles.
+
+    Parameters
+    ----------
+    tg : StatefulTractogram
+        Bundles to segment.
+    img : Nifti1Image
+        Image to use as reference.
+    mapping : DiffeomorphicMap, or equivalent interface
+        A mapping between DWI space and a template.
+    bundle_dict : dict or AFQ.api.BundleDict
+        Meta-data for the segmentation.
+    reg_template : str or Nifti1Image
+        Template to use for registration.
+    nb_points : int, boolean
+        Resample streamlines to nb_points number of points.
+        If False, no resampling is done. Default: False
+    nb_streamlines : int, boolean
+        Subsample streamlines to nb_streamlines.
+        If False, no subsampling is done. Default: False
+    clip_edges : bool
+        Whether to clip the streamlines to be only in between the ROIs.
+        Default: False
+    parallel_segmentation : dict
+        How to parallelize segmentation across processes when performing
+        waypoint ROI segmentation. Set to {"engine": "serial"} to not
+        perform parallelization. Some engines may cause errors, depending
+        on the system. See ``dipy.utils.parallel.paramap`` for
+        details.
+        Default: {"engine": "serial"}
+    rb_recognize_params : dict
+        RecoBundles parameters for the recognize function.
+    refine_reco : bool
+        Whether to refine the RecoBundles segmentation.
+    prob_threshold : float.
+        Using AFQ Algorithm.
+        Initial cleaning of fiber groups is done using probability maps
+        from [Hua2008]_. Here, we choose an average probability that
+        needs to be exceeded for an individual streamline to be retained.
+        Default: 0.
+    dist_to_waypoint : float.
+        The distance that a streamline node has to be from the waypoint
+        ROI in order to be included or excluded.
+        If set to None (default), will be calculated as the
+        center-to-corner distance of the voxel in the diffusion data.
+        If a bundle has inc_addtol or exc_addtol in its bundle_dict, that
+        tolerance will be added to this distance.
+        For example, if you wanted to increase tolerance for the right
+        arcuate waypoint ROIs by 3 each, you could make the following
+        modification to your bundle_dict:
+        bundle_dict["Right Arcuate"]["inc_addtol"] = [3, 3]
+        Additional tolerances can also be negative.
+    rng : RandomState or int
+        If None, creates RandomState.
+ If int, creates RandomState with seed rng. + Used in RecoBundles Algorithm. + Default: None. + return_idx : bool + Whether to return the indices in the original streamlines as part + of the output of segmentation. + filter_by_endpoints: bool + Whether to filter the bundles based on their endpoints. + Default: True. + dist_to_atlas : float + If filter_by_endpoints is True, this is the required distance + from the endpoints to the atlas ROIs. + save_intermediates : str, optional + The full path to a folder into which intermediate products + are saved. Default: None, means no saving of intermediates. + cleaning_params : dict, optional + Cleaning params to pass to seg.clean_bundle. This will + override the default parameters of that method. However, this + can be overriden by setting the cleaning parameters in the + bundle_dict. Default: {}. + + References + ---------- + .. [Hua2008] Hua K, Zhang J, Wakana S, Jiang H, Li X, et al. (2008) + Tract probability maps in stereotaxic spaces: analyses of white + matter anatomy and tract-specific quantification. Neuroimage 39: + 336-347 + .. [Yeatman2012] Yeatman, Jason D., Robert F. Dougherty, Nathaniel J. + Myall, Brian A. Wandell, and Heidi M. Feldman. 2012. "Tract Profiles of + White Matter Properties: Automating Fiber-Tract Quantification" + PloS One 7 (11): e49790. + .. [Garyfallidis17] Garyfallidis et al. Recognition of white matter + bundles using local and global streamline-based registration and + clustering, Neuroimage, 2017. + """ + if rng is None: + rng = np.random.RandomState() + elif isinstance(rng, int): + rng = np.random.RandomState(rng) + + if (save_intermediates is not None) and \ + (not op.exists(save_intermediates)): + os.makedirs(save_intermediates, exist_ok=True) + + logger.info("Preprocessing Streamlines") + tg = abu.read_tg(tg, nb_streamlines) + + # If resampling over-write the sft: + if nb_points: + tg = StatefulTractogram( + dps.set_number_of_points(tg.streamlines, nb_points), + tg, tg.space) + + if not isinstance(bundle_dict, BundleDict): + bundle_dict = BundleDict(bundle_dict) + + tg.to_vox() + n_streamlines = len(tg) + bundle_decisions = np.zeros( + (n_streamlines, len(bundle_dict)), + dtype=np.bool8) + bundle_to_flip = np.zeros( + (n_streamlines, len(bundle_dict)), + dtype=np.bool8) + bundle_roi_dists = -np.ones( + ( + n_streamlines, + len(bundle_dict), + bundle_dict.max_includes), + dtype=np.uint32) + + fiber_groups = {} + meta = {} + + preproc_imap = get_preproc_plan(img, tg, dist_to_waypoint, dist_to_atlas) + + logger.info("Assigning Streamlines to Bundles") + for bundle_idx, bundle_name in enumerate( + bundle_dict.bundle_names): + logger.info(f"Finding Streamlines for {bundle_name}") + run_bundle_rec_plan( + bundle_dict, tg, mapping, img, reg_template, preproc_imap, + bundle_name, bundle_idx, bundle_to_flip, bundle_roi_dists, + bundle_decisions, + clip_edges=clip_edges, + parallel_segmentation=parallel_segmentation, + rb_recognize_params=rb_recognize_params, + prob_threshold=prob_threshold, + refine_reco=refine_reco, + rng=rng, + return_idx=return_idx, + filter_by_endpoints=filter_by_endpoints, + save_intermediates=save_intermediates, + cleaning_params=cleaning_params) + + if save_intermediates is not None: + os.makedirs(save_intermediates, exist_ok=True) + bc_path = op.join(save_intermediates, + "sls_bundle_decisions.npy") + np.save(bc_path, bundle_decisions) + + conflicts = np.sum(np.sum(bundle_decisions, axis=1) > 1) + if conflicts > 0: + logger.warning(( + "Conflicts in bundle assignment detected. 
" + f"{conflicts} conflicts detected in total out of " + f"{n_streamlines} total streamlines." + "Defaulting to whichever bundle appears first" + "in the bundle_dict.")) + bundle_decisions = np.concatenate(( + bundle_decisions, np.ones((n_streamlines, 1))), axis=1) + bundle_decisions = np.argmax(bundle_decisions, -1) + + # Helper functions for formatting the results + def _return_empty(bundle_name): + """ + Helper function to return an empty dict under + some conditions. + """ + if return_idx: + fiber_groups[bundle_name] = {} + fiber_groups[bundle_name]['sl'] = StatefulTractogram( + [], img, Space.VOX) + fiber_groups[bundle_name]['idx'] = np.array([]) + else: + fiber_groups[bundle_name] = StatefulTractogram( + [], img, Space.VOX) + + def _add_bundle_to_fiber_group(b_name, sl, idx, to_flip): + """ + Helper function to add a bundle to a fiber group. + """ + sl = abu.flip_sls( + sl, to_flip, + in_place=False) + + sl = StatefulTractogram( + sl, + img, + Space.VOX) + + if return_idx: + fiber_groups[b_name] = {} + fiber_groups[b_name]['sl'] = sl + fiber_groups[b_name]['idx'] = idx + else: + fiber_groups[b_name] = sl + + def _add_bundle_to_meta(bundle_name, b_def): + # remove keys that can never be serialized + for key in [ + 'include', 'exclude', 'prob_map', + 'start', 'end', 'curvature']: + b_def.pop(key, None) + meta[bundle_name] = b_def + + # We do another round through, so that we can orient all the + # streamlines within a bundle in the same orientation with respect to + # the ROIs. This order is ARBITRARY but CONSISTENT (going from ROI0 + # to ROI1). + logger.info("Re-orienting streamlines to consistent directions") + for bundle_idx, bundle in enumerate(bundle_dict.bundle_names): + logger.info(f"Processing {bundle}") + + select_idx = np.where(bundle_decisions == bundle_idx)[0] + + if len(select_idx) == 0: + # There's nothing here, set and move to the next bundle: + if "bundlesection" in bundle_dict.get_b_info(bundle): + for sb_name in bundle_dict.get_b_info(bundle)[ + "bundlesection"]: + _return_empty(sb_name) + else: + _return_empty(bundle) + continue + + # Use a list here, because ArraySequence doesn't support item + # assignment: + select_sl = list(tg.streamlines[select_idx]) + roi_dists = bundle_roi_dists[select_idx, bundle_idx, :] + n_includes = len(bundle_dict.get_b_info( + bundle).get("include", [])) + if clip_edges and n_includes > 1: + logger.info("Clipping Streamlines by ROI") + select_sl = abu.cut_sls_by_dist( + select_sl, roi_dists, + (0, n_includes - 1), in_place=True) + + to_flip = bundle_to_flip[select_idx, bundle_idx] + b_def = dict(bundle_dict.get_b_info(bundle_name)) + if "bundlesection" in b_def: + for sb_name, sb_include_cuts in bundle_dict.get_b_info( + bundle)["bundlesection"].items(): + bundlesection_select_sl = abu.cut_sls_by_dist( + select_sl, roi_dists, + sb_include_cuts, in_place=False) + _add_bundle_to_fiber_group( + sb_name, bundlesection_select_sl, select_idx, + to_flip) + _add_bundle_to_meta(sb_name, b_def) + else: + _add_bundle_to_fiber_group( + bundle, select_sl, select_idx, to_flip) + _add_bundle_to_meta(bundle, b_def) + return fiber_groups, meta diff --git a/AFQ/bundle_rec/roi.py b/AFQ/bundle_rec/roi.py new file mode 100644 index 000000000..26500c146 --- /dev/null +++ b/AFQ/bundle_rec/roi.py @@ -0,0 +1,94 @@ +import numpy as np +from scipy.spatial.distance import cdist +from scipy.ndimage import binary_dilation + + +def check_sls_with_inclusion(sls, include_rois, include_roi_tols): + for sl in sls: + yield check_sl_with_inclusion( + sl, + include_rois, + 
+
+
+def check_sl_with_inclusion(sl, include_rois,
+                            include_roi_tols):
+    """
+    Helper function to check that a streamline is close to a list of
+    inclusion ROIs.
+    """
+    dist = []
+    for ii, roi in enumerate(include_rois):
+        # Use squared Euclidean distance, because it's faster:
+        dist.append(cdist(sl, roi, 'sqeuclidean'))
+        if np.min(dist[-1]) > include_roi_tols[ii]:
+            # Too far from one of them:
+            return False, []
+    # Apparently you checked all the ROIs and it was close to all of them
+    return True, dist
+
+
+def check_sl_with_exclusion(sl, exclude_rois,
+                            exclude_roi_tols):
+    """ Helper function to check that a streamline is not too close to a
+    list of exclusion ROIs.
+    """
+    for ii, roi in enumerate(exclude_rois):
+        # Use squared Euclidean distance, because it's faster:
+        if np.min(cdist(sl, roi, 'sqeuclidean')) < exclude_roi_tols[ii]:
+            return False
+    # Either there are no exclusion ROIs, or you are not close to any:
+    return True
+
+
+def clean_by_endpoints(streamlines, target, target_idx, tol=0,
+                       flip_sls=None, accepted_idxs=None):
+    """
+    Clean a collection of streamlines based on an endpoint ROI.
+    Filters down to only include items that have their start or end points
+    close to the targets.
+
+    Parameters
+    ----------
+    streamlines : sequence of N by 3 arrays
+        Where N is number of nodes in the array, the collection of
+        streamlines to filter down to.
+    target: Nifti1Image
+        Nifti1Image containing a boolean representation of the ROI.
+    target_idx: int.
+        Index within each streamline to check if within the target region.
+        Typically 0 for startpoint ROIs or -1 for endpoint ROIs.
+        If using flip_sls, this becomes (len(sl) - this_idx - 1) % len(sl)
+    tol : int, optional
+        A distance tolerance (in units that the coordinates
+        of the streamlines are represented in). Default: 0, which means that
+        the endpoint is exactly in the coordinate of the target ROI.
+    flip_sls : 1d array, optional
+        Length is len(streamlines), whether to flip the streamline.
+    accepted_idxs : 1d array, optional
+        Boolean array, where entries correspond to each streamline,
+        and streamlines that pass cleaning will be set to 1.
+
+    Returns
+    -------
+    1d boolean array of streamlines that survive cleaning.
+
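+    Examples
+    --------
+    A minimal sketch (illustrative only, not from the pyAFQ tests),
+    using a single-voxel ROI and one streamline in voxel coordinates:
+
+    >>> import numpy as np
+    >>> import nibabel as nib
+    >>> roi_data = np.zeros((10, 10, 10))
+    >>> roi_data[5, 5, 5] = 1
+    >>> target = nib.Nifti1Image(roi_data, np.eye(4))
+    >>> sls = [np.array([[5., 5., 5.], [9., 9., 9.]])]
+    >>> clean_by_endpoints(sls, target, 0)
+    array([ True])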
+ """ + if accepted_idxs is None: + accepted_idxs = np.zeros(len(streamlines), dtype=np.bool8) + + if flip_sls is None: + flip_sls = np.zeros(len(streamlines)) + flip_sls = flip_sls.astype(int) + + roi = target.get_fdata() + if tol > 0: + roi = binary_dilation( + roi, + iterations=tol) + + for ii, sl in enumerate(streamlines): + this_idx = target_idx + if flip_sls[ii]: + this_idx = (len(sl) - this_idx - 1) % len(sl) + xx, yy, zz = sl[this_idx].astype(int) + accepted_idxs[ii] = roi[xx, yy, zz] + + return accepted_idxs diff --git a/AFQ/bundle_rec/utils.py b/AFQ/bundle_rec/utils.py new file mode 100644 index 000000000..a8c8aded5 --- /dev/null +++ b/AFQ/bundle_rec/utils.py @@ -0,0 +1,222 @@ +import numpy as np +import os.path as op +from time import time + +import logging + +from dipy.io.stateful_tractogram import StatefulTractogram, Space +from dipy.io.streamline import save_tractogram +import dipy.tracking.streamlinespeed as dps +import dipy.tracking.streamline as dts +from dipy.tracking.distances import bundles_distances_mdf + + +from AFQ.definitions.mapping import ConformedFnirtMapping + + +logger = logging.getLogger('AFQ') + + +def flip_sls(select_sl, idx_to_flip, in_place=False): + """ + Helper function to flip streamlines + """ + if in_place: + flipped_sl = select_sl + else: + flipped_sl = [None] * len(select_sl) + for ii, sl in enumerate(select_sl): + if idx_to_flip[ii]: + flipped_sl[ii] = sl[::-1] + else: + flipped_sl[ii] = sl + return flipped_sl + + +def cut_sls_by_dist(select_sl, roi_dists, roi_idxs, + in_place=False): + """ + Helper function to cut streamlines according to which points + are closest to certain rois. + + Parameters + ---------- + select_sl, streamlines to cut + roi_dists, distances from a given streamline to a given inclusion roi + roi_idxs, two indices into the list of inclusion rois to use for the cut + in_place, whether to modify select_sl + """ + if in_place: + cut_sl = select_sl + else: + cut_sl = [None] * len(select_sl) + + for idx, this_sl in enumerate(select_sl): + if roi_idxs[0] == -1: + min0 = 0 + else: + min0 = int(roi_dists[idx, roi_idxs[0]]) + if roi_idxs[1] == -1: + min1 = len(this_sl) + else: + min1 = int(roi_dists[idx, roi_idxs[1]]) + + # handle if sls not flipped + if min0 > min1: + min0, min1 = min1, min0 + + # If the point that is closest to the first ROI + # is the same as the point closest to the second ROI, + # include the surrounding points to make a streamline. + if min0 == min1: + min1 = min1 + 1 + min0 = min0 - 1 + + cut_sl[idx] = this_sl[min0:min1] + + return cut_sl + + +def read_tg(tg, nb_streamlines=None): + if nb_streamlines and len(tg) > nb_streamlines: + tg = StatefulTractogram.from_sft( + dts.select_random_set_of_streamlines( + tg.streamlines, + nb_streamlines + ), + tg) + return tg + + +def orient_by_streamline(sls, template_sl): + DM = bundles_distances_mdf( + sls, + [template_sl, template_sl[::-1]]) + return DM[:, 0] > DM[:, 1] + + +def move_streamlines(tg, to, mapping, img, save_intermediates=None): + """Move streamlines to or from template space. + + to : str + Either "template" or "subject". + mapping : ConformedMapping + Mapping to use to move streamlines. + img : Nifti1Image + Space to move streamlines to. + """ + tg_og_space = tg.space + if isinstance(mapping, ConformedFnirtMapping): + if to != "subject": + raise ValueError( + "Attempted to transform streamlines to template using " + "unsupported mapping. 
" + "Use something other than Fnirt.") + tg.to_vox() + moved_sl = [] + for sl in tg.streamlines: + moved_sl.append(mapping.transform_inverse_pts(sl)) + else: + tg.to_rasmm() + if to == "template": + volume = mapping.forward + else: + volume = mapping.backward + delta = dts.values_from_volume( + volume, + tg.streamlines, np.eye(4)) + moved_sl = dts.Streamlines( + [d + s for d, s in zip(delta, tg.streamlines)]) + moved_sft = StatefulTractogram( + moved_sl, + img, + Space.RASMM) + if save_intermediates is not None: + save_tractogram( + moved_sft, + op.join(save_intermediates, + f'sls_in_{to}.trk'), + bbox_valid_check=False) + tg.to_space(tg_og_space) + return moved_sft + + +def resample_tg(tg, n_points): + # reformat for dipy's set_number_of_points + if isinstance(tg, np.ndarray): + if len(tg.shape) > 2: + streamlines = tg.tolist() + streamlines = [np.asarray(item) for item in streamlines] + elif hasattr(tg, "streamlines"): + streamlines = tg.streamlines + else: + streamlines = tg + + return dps.set_number_of_points(streamlines, n_points) + + +class SlsBeingRecognized: + def __init__(self, sls, logger, save_intermediates, b_name, ref, + n_roi_dists): + self.oriented_yet = False + self.selected_fiber_idxs = np.arange(len(sls), dtype=np.uint32) + self.sls_flipped = np.zeros(len(sls), dtype=np.bool8) + self.logger = logger + self.start_time = -1 + self.save_intermediates = save_intermediates + self.b_name = b_name + self.ref_sls = sls + self.ref = ref + self.n_roi_dists = n_roi_dists + + def initiate_selection(self, clean_name): + self.start_time = time() + self.logger.info(f"Filtering by {clean_name}") + return np.zeros(len(self.selected_fiber_idxs), dtype=np.bool8) + + def select(self, idx, clean_name, cut=False): + self.selected_fiber_idxs = self.selected_fiber_idxs[idx] + self.sls_flipped = self.sls_flipped[idx] + if hasattr(self, "roi_dists"): + self.roi_dists = self.roi_dists[idx] + time_taken = time() - self.start_time + self.logger.info( + f"After filtering by {clean_name} (time: {time_taken}s), " + f"{len(self)} streamlines remain.") + if self.save_intermediates is not None: + save_tractogram( + StatefulTractogram( + self.get_selected_sls(cut=cut), + self.ref, Space.VOX), + op.join(self.save_intermediates, + f'sls_after_{clean_name}_for_{self.b_name}.trk'), + bbox_valid_check=False) + + def get_selected_sls(self, cut=False, flip=False): + selected_sls = self.ref_sls[self.selected_fiber_idxs] + if cut and hasattr(self, "roi_dists") and self.n_roi_dists > 1: + selected_sls = cut_sls_by_dist( + selected_sls, self.roi_dists, + (0, self.n_roi_dists - 1), + in_place=False) + if flip: + selected_sls = flip_sls( + selected_sls, self.sls_flipped, + in_place=False) + return selected_sls + + def reorient(self, idx): + if self.oriented_yet: + raise RuntimeError(( + "Attempted to oriented streamlines " + "that were already oriented. " + "This is a bug in the implementation of a " + "bundle recognition procedure. 
")) + self.oriented_yet = True + self.sls_flipped[idx] = True + + def __bool__(self): + return len(self) > 0 + + def __len__(self): + return len(self.selected_fiber_idxs) diff --git a/AFQ/data/fetch.py b/AFQ/data/fetch.py index b0e3735a2..15d30c91c 100644 --- a/AFQ/data/fetch.py +++ b/AFQ/data/fetch.py @@ -1283,7 +1283,7 @@ def read_hcp_atlas(n_bundles=16, as_file=False): for bundle_file in bundle_files: bundle = drop_extension(op.split(bundle_file)[-1]) centroid_file = op.join(centroid_folder, f"{bundle}.trk") - bundle_dict[bundle] = {} + bundle_dict[bundle] = {"recobundles": {}} if not op.exists(centroid_file): bundle_sl = load_tractogram( bundle_file, @@ -1299,16 +1299,16 @@ def read_hcp_atlas(n_bundles=16, as_file=False): centroid_file, bbox_valid_check=False) if not as_file: - bundle_dict[bundle]['sl'] = load_tractogram( + bundle_dict[bundle]["recobundles"]['sl'] = load_tractogram( bundle_file, 'same', bbox_valid_check=False).streamlines - bundle_dict[bundle]['centroid'] = load_tractogram( + bundle_dict[bundle]["recobundles"]['centroid'] = load_tractogram( centroid_file, "same", bbox_valid_check=False).streamlines else: - bundle_dict[bundle]['sl'] = bundle_file - bundle_dict[bundle]['centroid'] = centroid_file + bundle_dict[bundle]["recobundles"]['sl'] = bundle_file + bundle_dict[bundle]["recobundles"]['centroid'] = centroid_file # For some reason, this file-name has a 0 in it, instead of an O: bundle_dict["IFOF_R"] = bundle_dict["IF0F_R"] diff --git a/AFQ/segmentation.py b/AFQ/segmentation.py deleted file mode 100644 index 6c03986a1..000000000 --- a/AFQ/segmentation.py +++ /dev/null @@ -1,1444 +0,0 @@ -import os.path as op -import os -import logging -from time import time - -import numpy as np -from scipy.spatial.distance import cdist -from scipy.stats import zscore -from scipy.ndimage import binary_dilation - -import dipy.tracking.streamline as dts -import dipy.tracking.streamlinespeed as dps -from dipy.segment.bundles import RecoBundles -from dipy.stats.analysis import gaussian_weights -from dipy.io.stateful_tractogram import StatefulTractogram, Space -from dipy.io.streamline import save_tractogram, load_tractogram -from dipy.utils.parallel import paramap -from dipy.segment.clustering import QuickBundles -from dipy.segment.metricspeed import AveragePointwiseEuclideanMetric -from dipy.segment.featurespeed import ResampleFeature - -import AFQ.utils.models as ut -import AFQ.data.fetch as afd -from AFQ.data.utils import BUNDLE_RECO_2_AFQ -from AFQ.api.bundle_dict import BundleDict -from AFQ.definitions.mapping import ConformedFnirtMapping -from AFQ._fixes import gaussian_weights - -__all__ = ["Segmentation", "clean_bundle", "clean_by_endpoints"] - - -logger = logging.getLogger('AFQ') - - -def _resample_tg(tg, n_points): - # reformat for dipy's set_number_of_points - if isinstance(tg, np.ndarray): - if len(tg.shape) > 2: - streamlines = tg.tolist() - streamlines = [np.asarray(item) for item in streamlines] - elif hasattr(tg, "streamlines"): - streamlines = tg.streamlines - else: - streamlines = tg - - return dps.set_number_of_points(streamlines, n_points) - - -class _SlsBeingRecognized: - def __init__(self, sls, logger, save_intermediates, b_name, ref, - n_roi_dists): - self.oriented_yet = False - self.selected_fiber_idxs = np.arange(len(sls), dtype=np.uint32) - self.sls_flipped = np.zeros(len(sls), dtype=np.bool8) - self.bundle_vote = np.full(len(sls), -np.inf, dtype=np.float32) - self.logger = logger - self.start_time = -1 - self.save_intermediates = save_intermediates - self.b_name = 
b_name - self.ref_sls = sls - self.ref = ref - self.n_roi_dists = n_roi_dists - - def initiate_selection(self, clean_name): - self.start_time = time() - self.logger.info(f"Filtering by {clean_name}") - return np.zeros(len(self.selected_fiber_idxs), dtype=np.bool8) - - def select(self, idx, clean_name, cut=False): - self.selected_fiber_idxs = self.selected_fiber_idxs[idx] - self.sls_flipped = self.sls_flipped[idx] - self.bundle_vote = self.bundle_vote[idx] - if hasattr(self, "roi_dists"): - self.roi_dists = self.roi_dists[idx] - time_taken = time() - self.start_time - self.logger.info( - f"After filtering by {clean_name} (time: {time_taken}s), " - f"{len(self)} streamlines remain.") - if self.save_intermediates is not None: - save_tractogram( - StatefulTractogram( - self.get_selected_sls(cut=cut), - self.ref, Space.VOX), - op.join(self.save_intermediates, - f'sls_after_{clean_name}_for_{self.b_name}.trk'), - bbox_valid_check=False) - - def get_selected_sls(self, cut=False, flip=False): - selected_sls = self.ref_sls[self.selected_fiber_idxs] - if cut and hasattr(self, "roi_dists") and self.n_roi_dists > 1: - selected_sls = _cut_sls_by_dist( - selected_sls, self.roi_dists, - (0, self.n_roi_dists - 1), - in_place=False) - if flip: - selected_sls = _flip_sls( - selected_sls, self.sls_flipped, - in_place=False) - return selected_sls - - def reorient(self, idx): - if self.oriented_yet: - raise RuntimeError(( - "Attempted to oriented streamlines " - "that were already oriented. " - "This is a bug in the implementation of a " - "bundle recognition procedure. ")) - self.oriented_yet = True - self.sls_flipped[idx] = True - - def __bool__(self): - return len(self) > 0 - - def __len__(self): - return len(self.selected_fiber_idxs) - - -class Segmentation: - def __init__(self, - nb_points=False, - nb_streamlines=False, - seg_algo='AFQ', - clip_edges=False, - parallel_segmentation={"engine": "serial"}, - progressive=True, - greater_than=50, - rm_small_clusters=50, - model_clust_thr=1.25, - reduction_thr=25, - refine=False, - pruning_thr=12, - b0_threshold=50, - prob_threshold=0, - roi_dist_tie_break=False, - dist_to_waypoint=None, - rng=None, - return_idx=False, - presegment_bundle_dict=None, - presegment_kwargs={}, - filter_by_endpoints=True, - dist_to_atlas=4, - save_intermediates=None, - cleaning_params={}): - """ - Segment streamlines into bundles. - - Parameters - ---------- - nb_points : int, boolean - Resample streamlines to nb_points number of points. - If False, no resampling is done. Default: False - nb_streamlines : int, boolean - Subsample streamlines to nb_streamlines. - If False, no subsampling is don. Default: False - seg_algo : string - Algorithm for segmentation (case-insensitive): - 'AFQ': Segment streamlines into bundles, - based on inclusion/exclusion ROIs. - 'Reco': Segment streamlines using the RecoBundles algorithm - [Garyfallidis2017]. - Default: 'AFQ' - clip_edges : bool - Whether to clip the streamlines to be only in between the ROIs. - Default: False - parallel_segmentation : dict or AFQ.api.BundleDict - How to parallelize segmentation across processes when performing - waypoint ROI segmentation. Set to {"engine": "serial"} to not - perform parallelization. Some engines may cause errors, depending - on the system. See ``dipy.utils.parallel.paramap`` for - details. - Default: {"engine": "serial"} - rm_small_clusters : int - Using RecoBundles Algorithm. - Remove clusters that have less than this value - during whole brain SLR. 
- Default: 50 - model_clust_thr : int - Parameter passed on to recognize for Recobundles. - See Recobundles documentation. - Default: 1.25 - reduction_thr : int - Parameter passed on to recognize for Recobundles. - See Recobundles documentation. - Default: 25 - refine : bool - Parameter passed on to recognize for Recobundles. - See Recobundles documentation. - Default: False - pruning_thr : int - Parameter passed on to recognize for Recobundles. - See Recobundles documentation. - Default: 12 - progressive : boolean, optional - Using RecoBundles Algorithm. - Whether or not to use progressive technique - during whole brain SLR. - Default: True. - greater_than : int, optional - Using RecoBundles Algorithm. - Keep streamlines that have length greater than this value - during whole brain SLR. - Default: 50. - b0_threshold : float. - Using AFQ Algorithm. - All b-values with values less than or equal to `bo_threshold` are - considered as b0s i.e. without diffusion weighting. - Default: 50. - prob_threshold : float. - Using AFQ Algorithm. - Initial cleaning of fiber groups is done using probability maps - from [Hua2008]_. Here, we choose an average probability that - needs to be exceeded for an individual streamline to be retained. - Default: 0. - roi_dist_tie_break : bool. - Whether to use distance from nearest ROI as a tie breaker when a - streamline qualifies as a part of multiple bundles. If False, - probability maps are used. - Default : False. - dist_to_waypoint : float. - The distance that a streamline node has to be from the waypoint - ROI in order to be included or excluded. - If set to None (default), will be calculated as the - center-to-corner distance of the voxel in the diffusion data. - If a bundle has inc_addtol or exc_addtol in its bundle_dict, that - tolerance will be added to this distance. - For example, if you wanted to increase tolerance for the right - arcuate waypoint ROIs by 3 each, you could make the following - modification to your bundle_dict: - bundle_dict["Right Arcuate"]["inc_addtol"] = [3, 3] - Additional tolerances can also be negative. - rng : RandomState or int - If None, creates RandomState. - If int, creates RandomState with seed rng. - Used in RecoBundles Algorithm. - Default: None. - return_idx : bool - Whether to return the indices in the original streamlines as part - of the output of segmentation. - presegment_bundle_dict : dict or None - If not None, presegment by ROIs before performing - RecoBundles. Only used if seg_algo starts with 'Reco'. - Meta-data for the segmentation. The format is something like:: - {'bundle_name': { - 'include':[img1, img2], - 'prob_map': img3, - 'cross_midline': False, - 'start': img4, - 'end': img5}} - Default: None - presegment_kwargs : dict - Optional arguments for initializing the segmentation for the - presegmentation. Only used if presegment_bundle_dict is not None. - Default: {} - filter_by_endpoints: bool - Whether to filter the bundles based on their endpoints. - Applies only when `seg_algo == 'AFQ'`. - Default: True. - dist_to_atlas : float - If filter_by_endpoints is True, this is the required distance - from the endpoints to the atlas ROIs. - save_intermediates : str, optional - The full path to a folder into which intermediate products - are saved. Default: None, means no saving of intermediates. - cleaning_params : dict, optional - Cleaning params to pass to seg.clean_bundle. This will - override the default parameters of that method. 
However, this - can be overriden by setting the cleaning parameters in the - bundle_dict. Default: {}. - - References - ---------- - .. [Hua2008] Hua K, Zhang J, Wakana S, Jiang H, Li X, et al. (2008) - Tract probability maps in stereotaxic spaces: analyses of white - matter anatomy and tract-specific quantification. Neuroimage 39: - 336-347 - """ - self.logger = logger - self.nb_points = nb_points - self.nb_streamlines = nb_streamlines - - if rng is None: - self.rng = np.random.RandomState() - elif isinstance(rng, int): - self.rng = np.random.RandomState(rng) - else: - self.rng = rng - - self.seg_algo = seg_algo.lower() - self.prob_threshold = prob_threshold - self.roi_dist_tie_break = roi_dist_tie_break - self.dist_to_waypoint = dist_to_waypoint - self.b0_threshold = b0_threshold - self.progressive = progressive - self.greater_than = greater_than - self.rm_small_clusters = rm_small_clusters - self.model_clust_thr = model_clust_thr - self.reduction_thr = reduction_thr - self.refine = refine - self.pruning_thr = pruning_thr - self.return_idx = return_idx - self.presegment_bundle_dict = presegment_bundle_dict - self.presegment_kwargs = presegment_kwargs - self.filter_by_endpoints = filter_by_endpoints - self.dist_to_atlas = dist_to_atlas - self.parallel_segmentation = parallel_segmentation - self.cleaning_params = cleaning_params - - if (save_intermediates is not None) and \ - (not op.exists(save_intermediates)): - os.makedirs(save_intermediates, exist_ok=True) - self.save_intermediates = save_intermediates - self.clip_edges = clip_edges - - def _read_tg(self, tg=None): - if tg is None: - tg = self.tg - else: - self.tg = tg - self._tg_orig_space = self.tg.space - - if self.nb_streamlines and len(self.tg) > self.nb_streamlines: - self.tg = StatefulTractogram.from_sft( - dts.select_random_set_of_streamlines( - self.tg.streamlines, - self.nb_streamlines - ), - self.tg - ) - - return tg - - def segment(self, bundle_dict, tg, mapping, img, - reg_prealign=None, - reg_template=None, reset_tg_space=False): - """ - Segment streamlines into bundles based on either waypoint ROIs - [Yeatman2012]_ or RecoBundles [Garyfallidis2017]_. - Parameters - ---------- - bundle_dict: dict or AFQ.api.BundleDict - Meta-data for the segmentation. The format is something like:: - {'bundle_name': { - 'include':[img1, img2], - 'prob_map': img3, - 'cross_midline': False, - 'start': img4, - 'end': img5}} - tg : StatefulTractogram - Bundles to segment - mapping : DiffeomorphicMap, or equivalent interface - A mapping between DWI space and a template. - img : Nifti1Image - Image to use as reference. - reg_prealign : array, optional. - The linear transformation to be applied to align input images to - the reference space before warping under the deformation field. - Default: None. - reg_template : str or nib.Nifti1Image, optional. - Template to use for registration. Default: MNI T2. - reset_tg_space : bool, optional - Whether to reset the space of the input tractogram after - segmentation is complete. Default: False. - - Returns - ------- - dict : Where keys are bundle names, values are tractograms of - these bundles. - - References - ---------- - .. [Yeatman2012] Yeatman, Jason D., Robert F. Dougherty, Nathaniel J. - Myall, Brian A. Wandell, and Heidi M. Feldman. 2012. "Tract Profiles of - White Matter Properties: Automating Fiber-Tract Quantification" - PloS One 7 (11): e49790. - .. [Garyfallidis17] Garyfallidis et al. 
Recognition of white matter - bundles using local and global streamline-based registration and - clustering, Neuroimage, 2017. - """ - self.img_affine = img.affine - self.img = img - self.logger.info("Preprocessing Streamlines") - tg = self._read_tg(tg) - - # These are calculated as-needed - self._fg_array = None - self._crosses = None - - # If resampling over-write the sft: - if self.nb_points: - self.tg = StatefulTractogram( - dps.set_number_of_points(self.tg.streamlines, self.nb_points), - self.tg, self.tg.space) - - if reg_template is None: - reg_template = afd.read_mni_template() - - self.reg_prealign = reg_prealign - self.reg_template = reg_template - self.mapping = mapping - self.bundle_dict = bundle_dict - if not isinstance(self.bundle_dict, BundleDict): - self.bundle_dict = BundleDict(self.bundle_dict) - - if self.seg_algo == "afq": - fiber_groups = self.segment_afq() - elif self.seg_algo.startswith("reco"): - fiber_groups = self.segment_reco() - else: - raise ValueError(f"The seg_algo input is {self.seg_algo}, which", - "is not recognized") - if reset_tg_space: - # Return the input to the original space when you are done: - self.tg.to_space(self._tg_orig_space) - - return fiber_groups - - @property - def fgarray(self): - """ - Streamlines resampled to 20 points. - """ - if self._fg_array is None: - self.logger.info("Resampling Streamlines...") - start_time = time() - self._fg_array = np.array(_resample_tg(self.tg, 20)) - self.logger.info(( - "Streamlines Resampled " - f"(time: {time()-start_time}s)")) - return self._fg_array - - @property - def crosses(self): - """ - Classify the streamlines by whether they cross the midline. - Creates a crosses attribute which is an array of booleans. Each boolean - corresponds to a streamline, and is whether or not that streamline - crosses the midline. - """ - if self._crosses is None: - # What is the x,y,z coordinate of 0,0,0 in the template space? - zero_coord = np.dot(np.linalg.inv(self.img_affine), - np.array([0, 0, 0, 1])) - - self._crosses = np.logical_and( - np.any(self.fgarray[:, :, 0] > zero_coord[0], axis=1), - np.any(self.fgarray[:, :, 0] < zero_coord[0], axis=1)) - return self._crosses - - def _return_empty(self, bundle): - """ - Helper function for segment_afq, to return an empty dict under - some conditions. - """ - - if self.return_idx: - self.fiber_groups[bundle] = {} - self.fiber_groups[bundle]['sl'] = StatefulTractogram( - [], self.img, Space.VOX) - self.fiber_groups[bundle]['idx'] = np.array([]) - else: - self.fiber_groups[bundle] = StatefulTractogram( - [], self.img, Space.VOX) - - def _add_bundle_to_fiber_group(self, b_name, sl, idx, to_flip): - """ - Helper function for segment_afq, to add a bundle - to a fiber group. 
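The midline test in the `crosses` property above reduces to two numpy checks over the resampled streamline array. A standalone sketch (the identity affine and random coordinates are illustrative, not pyAFQ data):

    import numpy as np

    # fgarray: (n_streamlines, n_points, 3) coordinates in voxel space;
    # img_affine: the reference image's voxel-to-world affine.
    img_affine = np.eye(4)
    fgarray = np.random.rand(10, 20, 3) * 20

    # Voxel-space coordinate that maps to world (0, 0, 0):
    zero_coord = np.dot(np.linalg.inv(img_affine), np.array([0, 0, 0, 1]))

    # A streamline crosses the midline if its x coordinates fall on both
    # sides of that coordinate:
    crosses = np.logical_and(
        np.any(fgarray[:, :, 0] > zero_coord[0], axis=1),
        np.any(fgarray[:, :, 0] < zero_coord[0], axis=1))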
- """ - sl = _flip_sls( - sl, to_flip, - in_place=False) - - sl = StatefulTractogram( - sl, - self.img, - Space.VOX) - - if self.return_idx: - self.fiber_groups[b_name] = {} - self.fiber_groups[b_name]['sl'] = sl - self.fiber_groups[b_name]['idx'] = idx - else: - self.fiber_groups[b_name] = sl - - def _add_bundle_to_meta(self, bundle_name, bundle_info): - bundle_info = bundle_info.copy() - # remove keys that can never be serialized - for key in [ - 'include', 'exclude', 'prob_map', - 'start', 'end', 'curvature']: - bundle_info.pop(key, None) - self.meta[bundle_name] = bundle_info - - def segment_afq(self, tg=None): - """ - Assign streamlines to bundles using the waypoint ROI approach - Parameters - ---------- - tg : StatefulTractogram class instance - """ - tg = self._read_tg(tg=tg) - tg.to_vox() - - n_streamlines = len(tg) - - bundle_votes = np.full( - (n_streamlines, len(self.bundle_dict)), - -np.inf, dtype=np.float32) - bundle_to_flip = np.zeros( - (n_streamlines, len(self.bundle_dict)), - dtype=np.bool8) - - bundle_roi_dists = -np.ones( - ( - n_streamlines, - len(self.bundle_dict), - self.bundle_dict.max_includes), - dtype=np.uint32) - - self.fiber_groups = {} - self.meta = {} - - # We need to calculate the size of a voxel, so we can transform - # from mm to voxel units: - R = self.img_affine[0:3, 0:3] - vox_dim = np.mean(np.diag(np.linalg.cholesky(R.T.dot(R)))) - - # Tolerance is set to the square of the distance to the corner - # because we are using the squared Euclidean distance in calls to - # `cdist` to make those calls faster. - if self.dist_to_waypoint is None: - tol = dts.dist_to_corner(self.img_affine) - else: - tol = self.dist_to_waypoint / vox_dim - dist_to_atlas = int(self.dist_to_atlas / vox_dim) - - self.logger.info("Assigning Streamlines to Bundles") - for bundle_idx, bundle_name in enumerate( - self.bundle_dict.bundle_names): - self.logger.info(f"Finding Streamlines for {bundle_name}") - - # Warp ROIs - self.logger.info(f"Preparing ROIs for {bundle_name}") - start_time = time() - bundle_def = dict(self.bundle_dict.get_b_info(bundle_name)) - bundle_def.update(self.bundle_dict.transform_rois( - bundle_name, - self.mapping, - self.img_affine)) - self.logger.info(f"Time to prep ROIs: {time()-start_time}s") - - if "curvature" in bundle_def: - self.logger.info(f"Loading curvature...") - start_time = time() - if "sft" in bundle_def["curvature"]: - ref_sl = bundle_def["curvature"]["sft"] - else: - ref_sl = load_tractogram( - bundle_def["curvature"]["path"], "same", - bbox_valid_check=False) - moved_ref_sl = self.move_streamlines( - ref_sl, "subject") - moved_ref_sl.to_vox() - moved_ref_sl = moved_ref_sl.streamlines[0] - moved_ref_curve = sl_curve( - moved_ref_sl, - len(moved_ref_sl)) - self.logger.info(( - "Time to load curves: " - f"{time()-start_time}s")) - - b_sls = _SlsBeingRecognized( - tg.streamlines, self.logger, - self.save_intermediates, bundle_name, - self.img, len(bundle_def.get("include", []))) - - # filter by probability map - if "prob_map" in bundle_def: - b_sls.initiate_selection("Prob. Map") - # using entire fgarray here only because it is the first step - fiber_probabilities = dts.values_from_volume( - bundle_def["prob_map"].get_fdata(), - self.fgarray, np.eye(4)) - fiber_probabilities = np.mean(fiber_probabilities, -1) - if not self.roi_dist_tie_break: - b_sls.bundle_vote = fiber_probabilities - b_sls.select( - fiber_probabilities > self.prob_threshold, - "Prob. 
Map") - elif not self.roi_dist_tie_break: - b_sls.bundle_vote = np.ones(len(b_sls)) - - if b_sls and "cross_midline" in bundle_def: - b_sls.initiate_selection("Cross Mid.") - accepted = self.crosses[b_sls.selected_fiber_idxs] - if not bundle_def["cross_midline"]: - accepted = np.invert(accepted) - b_sls.select(accepted, "Cross Mid.") - - if b_sls and "start" in bundle_def: - accept_idx = b_sls.initiate_selection("startpoint") - clean_by_endpoints( - b_sls.get_selected_sls(), - bundle_def["start"], - 0, - tol=dist_to_atlas, - flip_sls=b_sls.sls_flipped, - accepted_idxs=accept_idx) - if not b_sls.oriented_yet: - accepted_idx_flipped = clean_by_endpoints( - b_sls.get_selected_sls(), - bundle_def["start"], - -1, - tol=dist_to_atlas) - b_sls.reorient(accepted_idx_flipped) - accept_idx = np.logical_xor( - accepted_idx_flipped, accept_idx) - b_sls.select(accept_idx, "startpoint") - - if b_sls and "end" in bundle_def: - accept_idx = b_sls.initiate_selection("endpoint") - cleaned_idx = clean_by_endpoints( - b_sls.get_selected_sls(), - bundle_def["end"], - -1, - tol=dist_to_atlas, - flip_sls=b_sls.sls_flipped, - accepted_idxs=accept_idx) - if not b_sls.oriented_yet: - accepted_idx_flipped = clean_by_endpoints( - b_sls.get_selected_sls(), - bundle_def["end"], - 0, - tol=dist_to_atlas) - b_sls.reorient(accepted_idx_flipped) - accept_idx = np.logical_xor( - accepted_idx_flipped, accept_idx) - b_sls.select(accept_idx, "endpoint") - - if b_sls and ( - ("min_len" in bundle_def) or ("max_len" in bundle_def)): - accept_idx = b_sls.initiate_selection("length") - min_len = bundle_def.get("min_len", 0) / vox_dim - max_len = bundle_def.get("max_len", np.inf) / vox_dim - for idx, sl in enumerate(b_sls.get_selected_sls()): - sl_len = np.sum( - np.linalg.norm(np.diff(sl, axis=0), axis=1)) - if sl_len >= min_len and sl_len <= max_len: - accept_idx[idx] = 1 - b_sls.select(accept_idx, "length") - - if b_sls and "primary_axis" in bundle_def: - b_sls.initiate_selection("orientation") - accept_idx = clean_by_orientation( - b_sls.get_selected_sls(), - bundle_def["primary_axis"], - bundle_def.get( - "primary_axis_percentage", None)) - b_sls.select(accept_idx, "orientation") - - if b_sls and "include" in bundle_def: - accept_idx = b_sls.initiate_selection("include") - flip_using_include = len(bundle_def["include"]) > 1\ - and not b_sls.oriented_yet - - if f'inc_addtol' in bundle_def: - include_roi_tols = [] - for inc_tol in bundle_def["inc_addtol"]: - include_roi_tols.append((inc_tol / vox_dim + tol)**2) - else: - include_roi_tols = [tol**2] * len(bundle_def["include"]) - - include_rois = [] - for include_roi in bundle_def["include"]: - include_rois.append(np.array( - np.where(include_roi.get_fdata())).T) - - # with parallel segmentation, the first for loop will - # only collect streamlines and does not need tqdm - if self.parallel_segmentation["engine"] != "serial": - inc_results = paramap( - _check_sl_with_inclusion, b_sls.get_selected_sls(), - func_args=[ - include_rois, include_roi_tols], - **self.parallel_segmentation) - - else: - inc_results = _check_sls_with_inclusion( - b_sls.get_selected_sls(), - include_rois, - include_roi_tols) - - if self.roi_dist_tie_break: - min_dist_coords = np.ones(len(b_sls)) - roi_dists = -np.ones( - (len(b_sls), self.bundle_dict.max_includes), - dtype=np.int32) - if flip_using_include: - to_flip = np.ones_like(accept_idx, dtype=np.bool8) - for sl_idx, inc_result in enumerate(inc_results): - sl_accepted, sl_dist = inc_result - - if sl_accepted: - if self.roi_dist_tie_break: - 
min_dist_coords[sl_idx] = np.min(sl_dist) - - if len(sl_dist) > 1: - roi_dists[sl_idx, :len(sl_dist)] = [ - np.argmin(dist, 0)[0] - for dist in sl_dist] - first_roi_idx = roi_dists[sl_idx, 0] - last_roi_idx = roi_dists[ - sl_idx, len(sl_dist) - 1] - # Only accept SLs that, when cut, are meaningful - if (len(sl_dist) < 2) or abs( - first_roi_idx - last_roi_idx) > 1: - # Flip sl if it is close to second ROI - # before its close to the first ROI - if flip_using_include: - to_flip[sl_idx] =\ - first_roi_idx > last_roi_idx - if to_flip[sl_idx]: - roi_dists[sl_idx, :len(sl_dist)] =\ - np.flip(roi_dists[ - sl_idx, :len(sl_dist)]) - accept_idx[sl_idx] = 1 - else: - accept_idx[sl_idx] = 1 - # see https://github.com/joblib/joblib/issues/945 - if ( - (self.parallel_segmentation.get( - "engine", "joblib") != "serial") - and (self.parallel_segmentation.get( - "backend", "loky") == "loky")): - from joblib.externals.loky import get_reusable_executor - get_reusable_executor().shutdown(wait=True) - if self.roi_dist_tie_break: - b_sls.bundle_vote = -min_dist_coords - b_sls.roi_dists = roi_dists - if flip_using_include: - b_sls.reorient(to_flip) - b_sls.select(accept_idx, "include") - - # Filters streamlines by how well they match - # a curve in orientation and shape but not scale - if b_sls and "curvature" in bundle_def: - accept_idx = b_sls.initiate_selection("curvature") - ref_curve_threshold = np.radians(bundle_def["curvature"].get( - "thresh", 10)) - cut = bundle_def["curvature"].get("cut", True) - for idx, sl in enumerate(b_sls.get_selected_sls( - cut=cut, flip=True)): - if len(sl) > 1: - this_sl_curve = sl_curve(sl, len(moved_ref_sl)) - dist = sl_curve_dist(this_sl_curve, moved_ref_curve) - if dist <= ref_curve_threshold: - accept_idx[idx] = 1 - b_sls.select(accept_idx, "curvature", cut=cut) - - if b_sls and "exclude" in bundle_def: - accept_idx = b_sls.initiate_selection("exclude") - if f'exc_addtol' in bundle_def: - exclude_roi_tols = [] - for exc_tol in bundle_def["exc_addtol"]: - exclude_roi_tols.append((exc_tol / vox_dim + tol)**2) - else: - exclude_roi_tols = [tol**2] * len(bundle_def["exclude"]) - exclude_rois = [] - for exclude_roi in bundle_def["exclude"]: - exclude_rois.append(np.array( - np.where(exclude_roi.get_fdata())).T) - for sl_idx, sl in enumerate(b_sls.get_selected_sls()): - if _check_sl_with_exclusion( - sl, exclude_rois, exclude_roi_tols): - accept_idx[sl_idx] = 1 - b_sls.select(accept_idx, "exclude") - - if b_sls and "qb_thresh" in bundle_def: - b_sls.initiate_selection("qb_thresh") - cut = self.clip_edges or ("bundlesection" in bundle_def) - qbx = QuickBundles( - bundle_def["qb_thresh"] / vox_dim, - AveragePointwiseEuclideanMetric( - ResampleFeature(nb_points=12))) - clusters = qbx.cluster(b_sls.get_selected_sls( - cut=cut, flip=True)) - cleaned_idx = clusters[np.argmax( - clusters.clusters_sizes())].indices - b_sls.select(cleaned_idx, "qb_thresh", cut=cut) - - if b_sls: - accept_idx = b_sls.initiate_selection("Mahalanobis") - clean_params = bundle_def.get("mahal", {}) - clean_params = { - **self.cleaning_params, - **clean_params} - clean_params["return_idx"] = True - cut = self.clip_edges or ("bundlesection" in bundle_def) - _, cleaned_idx = clean_bundle( - b_sls.get_selected_sls(cut=cut, flip=True), - **clean_params) - b_sls.select(cleaned_idx, "Mahalanobis", cut=cut) - - if b_sls and not b_sls.oriented_yet: - raise ValueError( - "pyAFQ was unable to consistently orient streamlines " - f"in bundle {bundle_name} using the provided ROIs. 
" - "This can be fixed by including at least 2 " - "waypoint ROIs, or by using " - "endpoint ROIs.") - - if b_sls: - bundle_votes[ - b_sls.selected_fiber_idxs, - bundle_idx] = b_sls.bundle_vote.copy() - bundle_to_flip[ - b_sls.selected_fiber_idxs, - bundle_idx] = b_sls.sls_flipped.copy() - if hasattr(b_sls, "roi_dists"): - bundle_roi_dists[ - b_sls.selected_fiber_idxs, - bundle_idx - ] = b_sls.roi_dists.copy() - - if self.save_intermediates is not None: - os.makedirs(self.save_intermediates, exist_ok=True) - bc_path = op.join(self.save_intermediates, - "sls_bundle_votes.npy") - np.save(bc_path, bundle_votes) - - bundle_choice = np.argmax(bundle_votes, -1) - bundle_choice[bundle_votes.max(-1) == -np.inf] = -1 - - # We do another round through, so that we can orient all the - # streamlines within a bundle in the same orientation with respect to - # the ROIs. This order is ARBITRARY but CONSISTENT (going from ROI0 - # to ROI1). - self.logger.info("Re-orienting streamlines to consistent directions") - for bundle_idx, bundle in enumerate(self.bundle_dict.bundle_names): - self.logger.info(f"Processing {bundle}") - - select_idx = np.where(bundle_choice == bundle_idx)[0] - - if len(select_idx) == 0: - # There's nothing here, set and move to the next bundle: - if "bundlesection" in self.bundle_dict.get_b_info(bundle): - for sb_name in self.bundle_dict.get_b_info(bundle)[ - "bundlesection"]: - self._return_empty(sb_name) - else: - self._return_empty(bundle) - continue - - # Use a list here, because ArraySequence doesn't support item - # assignment: - select_sl = list(tg.streamlines[select_idx]) - roi_dists = bundle_roi_dists[select_idx, bundle_idx, :] - n_includes = len(self.bundle_dict.get_b_info( - bundle).get("include", [])) - if self.clip_edges and n_includes > 1: - self.logger.info("Clipping Streamlines by ROI") - _cut_sls_by_dist( - select_sl, roi_dists, - (0, n_includes - 1), in_place=True) - - to_flip = bundle_to_flip[select_idx, bundle_idx] - if "bundlesection" in self.bundle_dict[bundle]: - for sb_name, sb_include_cuts in self.bundle_dict.get_b_info( - bundle)["bundlesection"].items(): - bundlesection_select_sl = _cut_sls_by_dist( - select_sl, roi_dists, - sb_include_cuts, in_place=False) - self._add_bundle_to_fiber_group( - sb_name, bundlesection_select_sl, select_idx, to_flip) - self._add_bundle_to_meta(sb_name, bundle_def) - else: - self._add_bundle_to_fiber_group( - bundle, select_sl, select_idx, to_flip) - self._add_bundle_to_meta(bundle, bundle_def) - return self.fiber_groups, self.meta - - def move_streamlines(self, tg, to="template"): - """Streamline-based registration of a whole-brain tractogram to - the MNI whole-brain atlas. - - to : str - "template" or "subject" - """ - tg_og_space = tg.space - if isinstance(self.mapping, ConformedFnirtMapping): - if to != "subject": - raise ValueError( - "Attempted to transform streamlines to template using " - "unsupported mapping. 
" - "Use something other than Fnirt.") - tg.to_vox() - moved_sl = [] - for sl in tg.streamlines: - moved_sl.append(self.mapping.transform_inverse_pts(sl)) - else: - tg.to_rasmm() - if to == "template": - volume = self.mapping.forward - else: - volume = self.mapping.backward - delta = dts.values_from_volume( - volume, - tg.streamlines, np.eye(4)) - moved_sl = dts.Streamlines( - [d + s for d, s in zip(delta, tg.streamlines)]) - if to == "template": - ref = self.reg_template - else: - ref = self.img - moved_sft = StatefulTractogram( - moved_sl, - ref, - Space.RASMM) - if self.save_intermediates is not None: - save_tractogram( - moved_sft, - op.join(self.save_intermediates, - f'sls_in_{to}.trk'), - bbox_valid_check=False) - tg.to_space(tg_og_space) - return moved_sft - - def segment_reco(self, tg=None): - """ - Segment streamlines using the RecoBundles algorithm [Garyfallidis2017] - Parameters - ---------- - tg : StatefulTractogram class instance - A whole-brain tractogram to be segmented. - Returns - ------- - fiber_groups : dict - Keys are names of the bundles, values are Streamline objects. - The streamlines in each object have all been oriented to have the - same orientation (using `dts.orient_by_streamline`). - """ - tg = self._read_tg(tg=tg) - fiber_groups = {} - - # We generate our instance of RB with the moved streamlines: - self.logger.info("Extracting Bundles") - # If doing a presegmentation based on ROIs then initialize - # that segmentation and segment using ROIs, else - # RecoBundles based on the whole brain tractogram - if self.presegment_bundle_dict is not None: - roiseg = Segmentation(**self.presegment_kwargs) - roiseg.segment( - self.presegment_bundle_dict, - self.tg, - self.mapping, - self.img, - reg_template=self.reg_template, - reg_prealign=self.reg_prealign) - roiseg_fg = roiseg.fiber_groups - else: - moved_sl = self.move_streamlines(tg).streamlines - rb = RecoBundles(moved_sl, verbose=False, rng=self.rng) - # Next we'll iterate over bundles, registering each one: - bundle_list = list(self.bundle_dict.keys()) - if 'whole_brain' in bundle_list: - bundle_list.remove('whole_brain') - - self.logger.info("Assigning Streamlines to Bundles") - for bundle in bundle_list: - self.logger.info(f"Finding streamlines for {bundle}") - b_info = self.bundle_dict[bundle] - model_sl = b_info['sl'] - - # If doing a presegmentation based on ROIs then initialize rb after - # Filtering the whole brain tractogram to pass through ROIs - if self.presegment_bundle_dict is not None: - afq_bundle_name = BUNDLE_RECO_2_AFQ.get(bundle, bundle) - if "return_idx" in self.presegment_kwargs\ - and self.presegment_kwargs["return_idx"]: - indiv_tg = roiseg_fg[afq_bundle_name]['sl'] - else: - indiv_tg = roiseg_fg[afq_bundle_name] - - if len(indiv_tg.streamlines) < 1: - self.logger.warning(( - f"No streamlines found by waypoint ROI " - f"pre-segmentation for {bundle}. 
Using entire" - f" tractography instead.")) - indiv_tg = tg - - # Now rb should be initialized based on the fiber group coming - # out of the roi segmentation - indiv_tg = StatefulTractogram( - indiv_tg.streamlines, - self.img, - Space.VOX) - indiv_tg.to_rasmm() - moved_sl = self.move_streamlines(indiv_tg).streamlines - rb = RecoBundles( - moved_sl, - verbose=False, - rng=self.rng) - if self.save_intermediates is not None: - if self.presegment_bundle_dict is not None: - moved_fname = f"{bundle}_presegmentation.trk" - else: - moved_fname = "whole_brain.trk" - moved_sft = StatefulTractogram( - moved_sl, - self.reg_template, - Space.RASMM) - save_tractogram( - moved_sft, - op.join(self.save_intermediates, - moved_fname), - bbox_valid_check=False) - model_sft = StatefulTractogram( - model_sl, - self.reg_template, - Space.RASMM) - save_tractogram( - model_sft, - op.join(self.save_intermediates, - f"{bundle}_model.trk"), - bbox_valid_check=False) - - # Either whole brain tracgtogram or roi presegmented fiber group - # goes to rb.recognize - _, rec_labels = rb.recognize(model_bundle=model_sl, - model_clust_thr=self.model_clust_thr, - reduction_thr=self.reduction_thr, - reduction_distance='mdf', - slr=True, - slr_metric='asymmetric', - pruning_distance='mdf') - - # Use the streamlines in the original space: - if self.presegment_bundle_dict is None: - recognized_sl = tg.streamlines[rec_labels] - else: - recognized_sl = indiv_tg.streamlines[rec_labels] - if self.refine and len(recognized_sl) > 0: - _, rec_labels = rb.refine(model_sl, recognized_sl, - self.model_clust_thr, - reduction_thr=self.reduction_thr, - pruning_thr=self.pruning_thr) - if self.presegment_bundle_dict is None: - recognized_sl = tg.streamlines[rec_labels] - else: - recognized_sl = indiv_tg.streamlines[rec_labels] - standard_sl = next(iter(b_info['centroid'])) - oriented_sl = dts.orient_by_streamline(recognized_sl, standard_sl) - - self.logger.info( - f"{len(oriented_sl)} streamlines selected with Recobundles") - if self.return_idx: - fiber_groups[bundle] = {} - fiber_groups[bundle]['idx'] = rec_labels - fiber_groups[bundle]['sl'] = StatefulTractogram(oriented_sl, - self.img, - Space.RASMM) - else: - fiber_groups[bundle] = StatefulTractogram(oriented_sl, - self.img, - Space.RASMM) - self.fiber_groups = fiber_groups - return fiber_groups, {} - - -def sl_curve(sl, n_points): - """ - Calculate the direction of the displacement between - each point along a streamline - - Parameters - ---------- - sl : 2d array-like - Streamline to calcualte displacements for. - n_points : int - Number of points to resample the streamline to - - Returns - ------- - 2d array of shape (len(sl)-1, 3) with displacements - between each point in sl normalized to 1. - """ - # Resample to a standardized number of points - resampled_sl = dps.set_number_of_points( - sl, - n_points) - - # displacement at each point - resampled_sl_diff = np.diff(resampled_sl, axis=0) - - # normalize this displacement - resampled_sl_diff = resampled_sl_diff / np.linalg.norm( - resampled_sl_diff, axis=1)[:, None] - - return resampled_sl_diff - - -def sl_curve_dist(curve1, curve2): - """ - Calculate the mean angle using the directions of displacement - between two streamlines - - Parameters - ---------- - curve1, curve2 : 2d array-like - Two curves calculated from sl_curve. 
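The distance defined by these two helpers is the mean angle between matched unit displacement vectors. A minimal worked sketch with toy coordinates (the resampling that sl_curve performs is omitted here):

    import numpy as np

    def unit_displacements(sl):
        # Displacements between consecutive points, normalized to unit
        # length, as sl_curve computes after resampling.
        d = np.diff(sl, axis=0)
        return d / np.linalg.norm(d, axis=1)[:, None]

    curve1 = unit_displacements(np.array([[0., 0, 0], [1, 0, 0], [2, 0, 0]]))
    curve2 = unit_displacements(np.array([[0., 0, 0], [1, 0, 0], [2, 1, 0]]))

    # Mean angle across steps, in radians; clipping guards against dot
    # products just outside [-1, 1] from floating-point error:
    cos_angles = np.clip(np.sum(curve1 * curve2, axis=1), -1, 1)
    mean_angle = np.mean(np.arccos(cos_angles))  # pi/8: 0 and 45 degree steps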
- - Returns - ------- - The mean angle between each curve across all steps, in radians - """ - return np.mean(np.arccos(np.sum(curve1 * curve2, axis=1))) - - -def clean_bundle(tg, n_points=100, clean_rounds=5, distance_threshold=3, - length_threshold=4, min_sl=20, stat='mean', - return_idx=False): - """ - Clean a segmented fiber group based on the Mahalnobis distance of - each streamline - - Parameters - ---------- - tg : StatefulTractogram class instance or ArraySequence - A whole-brain tractogram to be segmented. - n_points : int, optional - Number of points to resample streamlines to. - Default: 100 - clean_rounds : int, optional. - Number of rounds of cleaning based on the Mahalanobis distance from - the mean of extracted bundles. Default: 5 - distance_threshold : float, optional. - Threshold of cleaning based on the Mahalanobis distance (the units are - standard deviations). Default: 3. - length_threshold: float, optional - Threshold for cleaning based on length (in standard deviations). Length - of any streamline should not be *more* than this number of stdevs from - the mean length. - min_sl : int, optional. - Number of streamlines in a bundle under which we will - not bother with cleaning outliers. Default: 20. - stat : callable or str, optional. - The statistic of each node relative to which the Mahalanobis is - calculated. Default: `np.mean` (but can also use median, etc.) - return_idx : bool - Whether to return indices in the original streamlines. - Default: False. - Returns - ------- - A StatefulTractogram class instance containing only the streamlines - that have a Mahalanobis distance smaller than `clean_threshold` from - the mean of each one of the nodes. - """ - # Convert string to callable, if that's what you got. - if isinstance(stat, str): - stat = getattr(np, stat) - - if hasattr(tg, "streamlines"): - streamlines = tg.streamlines - else: - streamlines = dts.Streamlines(tg) - - # We don't even bother if there aren't enough streamlines: - if len(streamlines) < min_sl: - if return_idx: - return tg, np.arange(len(streamlines)) - else: - return tg - - # Resample once up-front: - fgarray = np.asarray(_resample_tg(streamlines, n_points)) - - # Keep this around, so you can use it for indexing at the very end: - idx = np.arange(len(fgarray)) - # get lengths of each streamline - lengths = np.array([sl.shape[0] for sl in streamlines]) - # We'll only do this for clean_rounds - rounds_elapsed = 0 - idx_belong = idx - while (rounds_elapsed < clean_rounds) and (np.sum(idx_belong) > min_sl): - # Update by selection: - idx = idx[idx_belong] - fgarray = fgarray[idx_belong] - lengths = lengths[idx_belong] - rounds_elapsed += 1 - - # This calculates the Mahalanobis for each streamline/node: - m_dist = gaussian_weights( - fgarray, return_mahalnobis=True, - n_points=n_points, stat=stat) - logger.debug(f"Shape of fgarray: {np.asarray(fgarray).shape}") - logger.debug(f"Shape of m_dist: {m_dist.shape}") - logger.debug(f"Maximum m_dist: {np.max(m_dist)}") - logger.debug(( - f"Maximum m_dist for each fiber: " - f"{np.max(m_dist, axis=1)}")) - - length_z = zscore(lengths) - logger.debug(f"Shape of length_z: {length_z.shape}") - logger.debug(f"Maximum length_z: {np.max(length_z)}") - logger.debug(( - "length_z for each fiber: " - f"{length_z}")) - - if not ( - np.any(m_dist > distance_threshold) - or np.any(length_z > length_threshold)): - break - # Select the fibers that have Mahalanobis smaller than the - # threshold for all their nodes: - idx_dist = np.all(m_dist < distance_threshold, axis=-1) - 
idx_len = length_z < length_threshold - idx_belong = np.logical_and(idx_dist, idx_len) - - if np.sum(idx_belong) < min_sl: - # need to sort and return exactly min_sl: - idx_belong = np.argsort(np.sum( - m_dist, axis=-1))[:min_sl].astype(int) - logger.debug(( - f"At rounds elapsed {rounds_elapsed}, " - "minimum streamlines reached")) - else: - idx_removed = idx_belong == 0 - logger.debug(( - f"Rounds elapsed: {rounds_elapsed}, " - f"num removed: {np.sum(idx_removed)}")) - logger.debug(f"Removed indicies: {np.where(idx_removed)[0]}") - - # Select based on the variable that was keeping track of things for us: - if hasattr(tg, "streamlines"): - out = StatefulTractogram(tg.streamlines[idx], tg, Space.VOX) - else: - out = streamlines[idx] - if return_idx: - return out, idx - else: - return out - -# Helper functions for segmenting using waypoint ROIs -# they are not a part of the class because we do not want -# copies of the class to be parallelized - - -def _check_sls_with_inclusion(sls, include_rois, include_roi_tols): - for sl in sls: - yield _check_sl_with_inclusion( - sl, - include_rois, - include_roi_tols) - - -def _check_sl_with_inclusion(sl, include_rois, - include_roi_tols): - """ - Helper function to check that a streamline is close to a list of - inclusion ROIS. - """ - dist = [] - for ii, roi in enumerate(include_rois): - # Use squared Euclidean distance, because it's faster: - dist.append(cdist(sl, roi, 'sqeuclidean')) - if np.min(dist[-1]) > include_roi_tols[ii]: - # Too far from one of them: - return False, [] - # Apparently you checked all the ROIs and it was close to all of them - return True, dist - - -def _check_sl_with_exclusion(sl, exclude_rois, - exclude_roi_tols): - """ Helper function to check that a streamline is not too close to a - list of exclusion ROIs. - """ - for ii, roi in enumerate(exclude_rois): - # Use squared Euclidean distance, because it's faster: - if np.min(cdist(sl, roi, 'sqeuclidean')) < exclude_roi_tols[ii]: - return False - # Either there are no exclusion ROIs, or you are not close to any: - return True - - -def _flip_sls(select_sl, idx_to_flip, in_place=False): - """ - Helper function to flip streamlines - """ - if in_place: - flipped_sl = select_sl - else: - flipped_sl = [None] * len(select_sl) - for ii, sl in enumerate(select_sl): - if idx_to_flip[ii]: - flipped_sl[ii] = sl[::-1] - else: - flipped_sl[ii] = sl - return flipped_sl - - -def _cut_sls_by_dist(select_sl, roi_dists, roi_idxs, - in_place=False): - """ - Helper function to cut streamlines according to which points - are closest to certain rois. - - Parameters - ---------- - select_sl, streamlines to cut - roi_dists, distances from a given streamline to a given inclusion roi - roi_idxs, two indices into the list of inclusion rois to use for the cut - in_place, whether to modify select_sl - """ - if in_place: - cut_sl = select_sl - else: - cut_sl = [None] * len(select_sl) - - for idx, this_sl in enumerate(select_sl): - if roi_idxs[0] == -1: - min0 = 0 - else: - min0 = int(roi_dists[idx, roi_idxs[0]]) - if roi_idxs[1] == -1: - min1 = len(this_sl) - else: - min1 = int(roi_dists[idx, roi_idxs[1]]) - - # handle if sls not flipped - if min0 > min1: - min0, min1 = min1, min0 - - # If the point that is closest to the first ROI - # is the same as the point closest to the second ROI, - # include the surrounding points to make a streamline. 
- if min0 == min1: - min1 = min1 + 1 - min0 = min0 - 1 - - cut_sl[idx] = this_sl[min0:min1] - - return cut_sl - - -def clean_by_orientation(streamlines, primary_axis, tol=None): - """ - Compute the cardinal orientation of each streamline - - Parameters - ---------- - streamlines : sequence of N by 3 arrays - Where N is number of nodes in the array, the collection of - streamlines to filter down to. - - Returns - ------- - cleaned_idx, indicies of streamlines that passed cleaning, - logical_and of other two returns - along_accepted_idx, indices of streamlines that passed - cleaning along the bundle - end_accepted_idx, indices of streamlines that passed - cleaning based on difference between endpoints of bundle - """ - axis_diff = np.zeros((len(streamlines), 3)) - endpoint_diff = np.zeros((len(streamlines), 3)) - for ii, sl in enumerate(streamlines): - # endpoint diff is between first and last - endpoint_diff[ii, :] = np.abs(sl[0, :] - sl[-1, :]) - # axis diff is difference between the nodes, along - axis_diff[ii, :] = np.sum(np.abs(np.diff(sl, axis=0)), axis=0) - - orientation_along = np.argmax(axis_diff, axis=1) - along_accepted_idx = orientation_along == primary_axis - if tol is not None: - percentage_primary = 100 * axis_diff[:, primary_axis] / np.sum( - axis_diff, axis=1) - logger.debug(( - "Maximum primary percentage found: " - f"{np.max(percentage_primary)}")) - along_accepted_idx = np.logical_and( - along_accepted_idx, percentage_primary > tol) - - orientation_end = np.argmax(endpoint_diff, axis=1) - end_accepted_idx = orientation_end == primary_axis - - cleaned_idx = np.logical_and( - along_accepted_idx, - end_accepted_idx) - - return cleaned_idx - - -def clean_by_endpoints(streamlines, target, target_idx, tol=0, - flip_sls=None, accepted_idxs=None): - """ - Clean a collection of streamlines based on an endpoint ROI. - Filters down to only include items that have their start or end points - close to the targets. - Parameters - ---------- - streamlines : sequence of N by 3 arrays - Where N is number of nodes in the array, the collection of - streamlines to filter down to. - target: Nifti1Image - Nifti1Image containing a boolean representation of the ROI. - target_idx: int. - Index within each streamline to check if within the target region. - Typically 0 for startpoint ROIs or -1 for endpoint ROIs. - If using flip_sls, this becomes (len(sl) - this_idx - 1) % len(sl) - tol : int, optional - A distance tolerance (in units that the coordinates - of the streamlines are represented in). Default: 0, which means that - the endpoint is exactly in the coordinate of the target ROI. - flip_sls : 1d array, optional - Length is len(streamlines), whether to flip the streamline. - accepted_idxs : 1d array, optional - Boolean array, where entries correspond to eachs streamline, - and streamlines that pass cleaning will be set to 1. - Yields - ------- - boolean array of streamlines that survive cleaning. 
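With a single-voxel ROI, for example, only streamlines whose checked node lands on (or, with tol > 0, within a dilation of) the ROI survive. A toy sketch against the refactored module, mirroring test_clean_by_endpoints:

    import numpy as np
    import nibabel as nib
    import AFQ.bundle_rec.roi as abr

    streamlines = [np.array([[10., 10, 10], [10, 10, 15]]),
                   np.array([[0., 0, 0], [10, 10, 15]])]
    atlas = np.zeros((20, 20, 20), dtype=np.float32)
    atlas[10, 10, 10] = 1
    target = nib.Nifti1Image(atlas, np.eye(4))

    # Check the first node (index 0) of each streamline against the ROI:
    accepted = abr.clean_by_endpoints(streamlines, target, 0)
    # -> [True, False]; passing tol=1 would dilate the ROI by one voxel.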
- """ - if accepted_idxs is None: - accepted_idxs = np.zeros(len(streamlines), dtype=np.bool8) - - if flip_sls is None: - flip_sls = np.zeros(len(streamlines)) - flip_sls = flip_sls.astype(int) - - roi = target.get_fdata() - if tol > 0: - roi = binary_dilation( - roi, - iterations=tol) - - for ii, sl in enumerate(streamlines): - this_idx = target_idx - if flip_sls[ii]: - this_idx = (len(sl) - this_idx - 1) % len(sl) - xx, yy, zz = sl[this_idx].astype(int) - accepted_idxs[ii] = roi[xx, yy, zz] - - return accepted_idxs diff --git a/AFQ/tasks/data.py b/AFQ/tasks/data.py index 7b10e7c41..082fbc627 100644 --- a/AFQ/tasks/data.py +++ b/AFQ/tasks/data.py @@ -1064,13 +1064,7 @@ def get_bundle_dict(segmentation_params, " a dict, or a BundleDict")) if bundle_info is None: - if segmentation_params["seg_algo"] == "reco" or\ - segmentation_params["seg_algo"] == "reco16": - bundle_info = abd.reco_bd(16) - elif segmentation_params["seg_algo"] == "reco80": - bundle_info = abd.reco_bd(80) - else: - bundle_info = abd.default18_bd() + abd.callosal_bd() + bundle_info = abd.default18_bd() + abd.callosal_bd() use_brain_mask = True brain_mask = nib.load(brain_mask).get_fdata() @@ -1101,7 +1095,6 @@ def get_bundle_dict(segmentation_params, else: bundle_dict = abd.BundleDict( bundle_info, - seg_algo=segmentation_params["seg_algo"], resample_to=reg_template) return bundle_dict, reg_template diff --git a/AFQ/tasks/segmentation.py b/AFQ/tasks/segmentation.py index d20050f1e..9b6147cff 100644 --- a/AFQ/tasks/segmentation.py +++ b/AFQ/tasks/segmentation.py @@ -8,9 +8,9 @@ import pimms -from AFQ.tasks.decorators import as_file, as_img +from AFQ.tasks.decorators import as_file from AFQ.tasks.utils import get_fname, with_name, str_to_desc -import AFQ.segmentation as seg +from AFQ.bundle_rec.recognize import recognize from AFQ.utils.path import drop_extension, write_json import AFQ.utils.streamlines as aus from AFQ.tasks.utils import get_default_args @@ -66,13 +66,13 @@ def segment(data_imap, mapping_imap, logger.warning(f"{len(indices_to_remove)} invalid streamlines removed") start_time = time() - segmentation = seg.Segmentation(**segmentation_params) - bundles, bundle_meta = segmentation.segment( - bundle_dict, + bundles, bundle_meta = recognize( tg, - mapping_imap["mapping"], data_imap["dwi"], - reg_template=reg_template) + mapping_imap["mapping"], + bundle_dict, + reg_template, + **segmentation_params) seg_sft = aus.SegmentedSFT(bundles, Space.VOX) @@ -381,7 +381,7 @@ def get_segmentation_plan(kwargs): segment, tract_profiles]) - default_seg_params = get_default_args(seg.Segmentation.__init__) + default_seg_params = get_default_args(recognize) if "segmentation_params" in kwargs: for k in kwargs["segmentation_params"]: default_seg_params[k] = kwargs["segmentation_params"][k] diff --git a/AFQ/tasks/utils.py b/AFQ/tasks/utils.py index 4f3d6a12d..a0db74d4e 100644 --- a/AFQ/tasks/utils.py +++ b/AFQ/tasks/utils.py @@ -31,9 +31,8 @@ def get_fname(base_fname, suffix, fname = fname + ( f'_coordsys-RASMM_trkmethod-{directions+odf_model}' ) - if segmentation_params is not None and 'seg_algo' in segmentation_params: - seg_algo = segmentation_params['seg_algo'] - fname = fname + f"_recogmethod-{seg_algo}" + if segmentation_params is not None: + fname = fname + f"_recogmethod-AFQ" return fname + suffix diff --git a/AFQ/tests/test_api.py b/AFQ/tests/test_api.py index 016ea48c9..c4bb0f5d4 100644 --- a/AFQ/tests/test_api.py +++ b/AFQ/tests/test_api.py @@ -562,8 +562,8 @@ def test_AFQ_reco(): preproc_pipeline='vistasoft', 
viz_backend_spec="plotly", profile_weights="median", + bundle_info=abd.reco_bd(16), segmentation_params={ - 'seg_algo': 'reco', 'rng': 42}) seg_sft = aus.SegmentedSFT.fromfile( @@ -589,8 +589,8 @@ def test_AFQ_reco80(): bids_path=bids_path, preproc_pipeline='vistasoft', tracking_params=tracking_params, + bundle_info=abd.reco_bd(16), segmentation_params={ - 'seg_algo': 'reco80', 'rng': 42}) seg_sft = aus.SegmentedSFT.fromfile( @@ -812,7 +812,6 @@ def test_AFQ_data_waypoint(): random_seeds=True, rng_seed=42) segmentation_params = dict(filter_by_endpoints=False, - seg_algo="AFQ", return_idx=True) afq_folder = op.join(bids_path, "derivatives/afq/sub-01/ses-01") diff --git a/AFQ/tests/test_segmentation.py b/AFQ/tests/test_segmentation.py index 0b368031a..41f02b56c 100644 --- a/AFQ/tests/test_segmentation.py +++ b/AFQ/tests/test_segmentation.py @@ -12,8 +12,12 @@ from dipy.io.stateful_tractogram import StatefulTractogram, Space import AFQ.data.fetch as afd -import AFQ.segmentation as seg import AFQ.registration as reg +import AFQ.bundle_rec.cleaning as abc +import AFQ.bundle_rec.curvature as abv +import AFQ.bundle_rec.utils as abu +import AFQ.bundle_rec.roi as abr +from AFQ.bundle_rec.recognize import recognize dpd.fetch_stanford_hardi() @@ -63,12 +67,12 @@ 'cross_midline': None}} def test_segment(): - segmentation = seg.Segmentation() - segmentation.segment(bundles, - tg, - mapping, - nib.load(hardi_fdata)) - fiber_groups = segmentation.fiber_groups + fiber_groups, _ = recognize( + tg, + nib.load(hardi_fdata), + mapping, + bundles, + reg_template) # We asked for 2 fiber groups: npt.assert_equal(len(fiber_groups), 2) @@ -82,7 +86,7 @@ def test_segment(): CST_R_sl.streamlines, np.eye(4)) npt.assert_almost_equal(tract_profile, np.ones(100)) - clean_sl = seg.clean_bundle(CST_R_sl) + clean_sl = abc.clean_bundle(CST_R_sl) npt.assert_equal(len(clean_sl), len(CST_R_sl)) @@ -101,12 +105,12 @@ def test_segment_no_prob(): templates['CST_roi2_R']], 'cross_midline': False}} - segmentation = seg.Segmentation() - segmentation.segment(bundles_no_prob, - tg, - mapping, - nib.load(hardi_fdata)) - fiber_groups = segmentation.fiber_groups + fiber_groups, _ = recognize( + tg, + nib.load(hardi_fdata), + mapping, + bundles_no_prob, + reg_template) # This condition should still hold npt.assert_equal(len(fiber_groups), 2) @@ -115,44 +119,29 @@ def test_segment_no_prob(): def test_segment_return_idx(): # Test with the return_idx kwarg set to True: - segmentation = seg.Segmentation(return_idx=True) - segmentation.segment(bundles, - tg, - mapping, - nib.load(hardi_fdata)) - fiber_groups = segmentation.fiber_groups + fiber_groups, _ = recognize( + tg, + nib.load(hardi_fdata), + mapping, + bundles, + reg_template, + return_idx=True) npt.assert_equal(len(fiber_groups), 2) npt.assert_(len(fiber_groups['Right Corticospinal']['sl']) > 0) npt.assert_(len(fiber_groups['Right Corticospinal']['idx']) > 0) -def test_segment_keep_space(): - # Test with the return_idx kwarg set to True: - segmentation = seg.Segmentation(return_idx=True) - # We move the tg to rasmm to make sure that it ends - # up there - tg.to_rasmm() - orig_space = tg.space - segmentation.segment(bundles, - tg, - mapping, - nib.load(hardi_fdata), - reset_tg_space=True) - - npt.assert_equal(tg.space, orig_space) - - def test_segment_sl_curve(): - sl_disp_0 = seg.sl_curve(streamlines[4], 4) + sl_disp_0 = abv.sl_curve(streamlines[4], 4) npt.assert_array_almost_equal( sl_disp_0, [[-0.236384, -0.763855, 0.60054 ], [ 0.232594, -0.867859, -0.439 ], [ 0.175343, 0.001082, 
-0.984507]]) - sl_disp_1 = seg.sl_curve(streamlines[2], 4) - mean_angle_diff = seg.sl_curve_dist(sl_disp_0, sl_disp_1) + sl_disp_1 = abv.sl_curve(streamlines[2], 4) + mean_angle_diff = abv.sl_curve_dist(sl_disp_0, sl_disp_1) npt.assert_almost_equal(mean_angle_diff, 1.701458, decimal=3) @@ -165,7 +154,7 @@ def test_segment_clip_edges(): bundle_roi_dists[4, :] = [5, 10, 15] bundle_roi_dists[10, :] = [3, 6, 9] bundle_roi_dists[11, :] = [10, 10, 10] - cut_sls = seg._cut_sls_by_dist( + cut_sls = abu.cut_sls_by_dist( accepted_sls, bundle_roi_dists[accepted_ix], [0, 2]) @@ -183,14 +172,13 @@ def test_segment_clip_edges(): @pytest.mark.nightly def test_segment_clip_edges_api(): # Test with the clip_edges kwarg set to True: - segmentation = seg.Segmentation(clip_edges=True) - - fiber_groups, _ = segmentation.segment( - bundles, + fiber_groups, _ = recognize( tg, + nib.load(hardi_fdata), mapping, - nib.load(hardi_fdata)) - + bundles, + reg_template, + clip_edges=True) npt.assert_equal(len(fiber_groups), 2) npt.assert_(len(fiber_groups['Right Corticospinal']) > 0) @@ -198,20 +186,19 @@ def test_segment_clip_edges_api(): def test_segment_reco(): # get bundles for reco method bundles_reco = afd.read_hcp_atlas(16) - bundle_names = ['whole_brain', 'CST_R', 'CST_L'] + bundle_names = ['CST_R', 'CST_L'] for key in list(bundles_reco): if key not in bundle_names: bundles_reco.pop(key, None) # Try recobundles method - segmentation = seg.Segmentation(seg_algo='Reco', - progressive=False, - greater_than=10, - rm_small_clusters=1, - rng=np.random.RandomState(seed=8)) - fiber_groups, _ = segmentation.segment( - bundles_reco, tg, mapping, - nib.load(hardi_fdata)) + fiber_groups, _ = recognize( + tg, + nib.load(hardi_fdata), + mapping, + bundles_reco, + reg_template, + rng=np.random.RandomState(seed=8)) # This condition should still hold npt.assert_equal(len(fiber_groups), 2) @@ -246,26 +233,23 @@ def test_clean_by_endpoints(): target_img_end = nib.Nifti1Image( np.logical_or(atlas==3, atlas==4).astype(np.float32), np.eye(4)) - clean_idx_start = list(seg.clean_by_endpoints( + clean_idx_start = list(abr.clean_by_endpoints( sl, target_img_start, 0)) - clean_idx_end = list(seg.clean_by_endpoints( + clean_idx_end = list(abr.clean_by_endpoints( sl, target_img_end, -1)) npt.assert_array_equal(np.logical_and( clean_idx_start, clean_idx_end), np.array([1, 1, 0, 0])) # If tol=1, the third streamline also gets included - clean_idx_start = list(seg.clean_by_endpoints( + clean_idx_start = list(abr.clean_by_endpoints( sl, target_img_start, 0, tol=1)) - clean_idx_end = list(seg.clean_by_endpoints( + clean_idx_end = list(abr.clean_by_endpoints( sl, target_img_end, -1, tol=1)) npt.assert_array_equal(np.logical_and( clean_idx_start, clean_idx_end), np.array([1, 1, 1, 0])) def test_exclusion_ROI(): - segmentation = seg.Segmentation( - filter_by_endpoints=False - ) slf_bundle = { 'Left Superior Longitudinal': { 'include': [ @@ -286,50 +270,52 @@ def test_exclusion_ROI(): [30, 41, 62], [20, 44, 34]] ]).astype(float), hardi_img, Space.VOX) - fiber_groups, _ = segmentation.segment( - slf_bundle, + fiber_groups, _ = recognize( slf_tg, + nib.load(hardi_fdata), mapping, - nib.load(hardi_fdata)) + slf_bundle, + reg_template, + filter_by_endpoints=False) npt.assert_equal(len(fiber_groups["Left Superior Longitudinal"]), 2) slf_bundle['Left Superior Longitudinal']['exclude'] =\ [templates["SLFt_roi2_L"]] - fiber_groups, _ = segmentation.segment( - slf_bundle, + fiber_groups, _ = recognize( slf_tg, + nib.load(hardi_fdata), mapping, - 
nib.load(hardi_fdata)) + slf_bundle, + reg_template, + filter_by_endpoints=False) npt.assert_equal(len(fiber_groups["Left Superior Longitudinal"]), 1) def test_segment_orientation(): cleaned_idx = \ - seg.clean_by_orientation(streamlines, primary_axis=1) + abc.clean_by_orientation(streamlines, primary_axis=1) npt.assert_equal(np.sum(cleaned_idx), 93) cleaned_idx_tol = \ - seg.clean_by_orientation(streamlines, primary_axis=1, tol=50) + abc.clean_by_orientation(streamlines, primary_axis=1, tol=50) npt.assert_(np.sum(cleaned_idx_tol) < np.sum(cleaned_idx)) cleaned_idx = \ - seg.clean_by_orientation(streamlines, primary_axis=2) + abc.clean_by_orientation(streamlines, primary_axis=2) cleaned_idx_tol = \ - seg.clean_by_orientation(streamlines, primary_axis=2, tol=33) + abc.clean_by_orientation(streamlines, primary_axis=2, tol=33) npt.assert_array_equal(cleaned_idx_tol, cleaned_idx) def test_segment_sampled_streamlines(): - - # default segmentation - segmentation = seg.Segmentation() - fiber_groups, _ = segmentation.segment( - bundles, + fiber_groups, _ = recognize( tg, + nib.load(hardi_fdata), mapping, - nib.load(hardi_fdata)) + bundles, + reg_template) # Already using a subsampled tck # the Right Corticospinal has two streamlines and @@ -340,25 +326,20 @@ def test_segment_sampled_streamlines(): nb_streamlines = int(len(tg)*0.8) # sample and segment streamlines - sampled_segmentation = seg.Segmentation( - nb_streamlines=nb_streamlines - ) - - sampled_fiber_groups, _ = sampled_segmentation.segment( - bundles, + sampled_fiber_groups, _ = recognize( tg, + nib.load(hardi_fdata), mapping, - nib.load(hardi_fdata)) - - # sampled streamlines should equal the sample number - npt.assert_equal(len(sampled_segmentation.tg), nb_streamlines) + bundles, + reg_template, + nb_streamlines=nb_streamlines) # sampled streamlines should be subset of the original streamlines npt.assert_( np.all( np.isin( - sampled_segmentation.tg.streamlines._data, - tg.streamlines._data + sampled_fiber_groups['Right Corticospinal'].streamlines._data, + fiber_groups['Right Corticospinal'].streamlines._data ) ) ) diff --git a/AFQ/utils/bin.py b/AFQ/utils/bin.py index a2dc21e35..cbfcc3ad4 100644 --- a/AFQ/utils/bin.py +++ b/AFQ/utils/bin.py @@ -165,15 +165,16 @@ def dict_to_json(dictionary): def func_dict_to_arg_dict(func_dict=None, logger=None): if func_dict is None: - import AFQ.segmentation as seg + from AFQ.bundle_rec.recognize import recognize + from AFQ.bundle_rec.cleaning import clean_bundle import AFQ.tractography.tractography as aft from AFQ.api.group import GroupAFQ func_dict = { "BIDS": GroupAFQ.__init__, "Tractography": aft.track, - "Segmentation": seg.Segmentation.__init__, - "Cleaning": seg.clean_bundle} + "Segmentation": recognize, + "Cleaning": clean_bundle} arg_dict = {} for name, func in func_dict.items(): @@ -411,13 +412,14 @@ def generate_json(json_folder, overwrite=False, } ] }""" - import AFQ.segmentation as seg + from AFQ.bundle_rec.recognize import recognize + from AFQ.bundle_rec.cleaning import clean_bundle import AFQ.tractography.tractography as aft func_dict = { "Tractography": aft.track, - "Segmentation": seg.Segmentation.__init__, - "Cleaning": seg.clean_bundle} + "Segmentation": recognize, + "Cleaning": clean_bundle} arg_dict = func_dict_to_arg_dict(func_dict, logger=logger) diff --git a/docs/source/_progressbars.py b/docs/source/_progressbars.py index 308e1bc8f..5ad5786ff 100644 --- a/docs/source/_progressbars.py +++ b/docs/source/_progressbars.py @@ -22,11 +22,9 @@ def reset_progressbars(gallery_conf, 
 fname):
     # disable tqdm
     import AFQ._fixes as fixes
-    import AFQ.segmentation as seg
     import AFQ.viz.utils as utils
 
     fixes.tqdm = _no_tqdm
-    seg.tqdm = _no_tqdm
     utils.tqdm = _no_tqdm
 
     # disable update_progressbar
diff --git a/docs/source/reference/bundledict.rst b/docs/source/reference/bundledict.rst
index 9f62f95b7..44c0cd45f 100644
--- a/docs/source/reference/bundledict.rst
+++ b/docs/source/reference/bundledict.rst
@@ -29,6 +29,15 @@ key-value pairs::
       same length as 'exclude'. optional.
     - 'mahal': Dict describing the parameters for cleaning. By default, we
       use the default behavior of the seg.clean_bundle function.
+    - 'recobundles': Dict which should contain an 'sl' key and a 'centroid'
+      key. The 'sl' key should be the model streamlines and the 'centroid'
+      key should be the centroid streamline used to orient the bundle.
+    - 'qb_thresh': Float which is the threshold for QuickBundles cleaning.
+    - 'primary_axis': Int which is the index of the primary axis the
+      bundle should travel along.
+    - 'primary_axis_percentage': Used with primary_axis; defines what fraction
+      of a streamline's movement should be along the primary axis.
+    - 'length': Dictionary containing 'min_len' and 'max_len'.
 
 For an example, see "Plotting the Optic Radiations" in :ref:`examples`.
 
@@ -47,15 +56,12 @@ of the steps:
   7. Include
   8. Curvature
   9. Exclude
-  10. Quickbundles Cleaning
-  11. Mahalanobis Cleaning
+  10. Recobundles
+  11. QuickBundles Cleaning
+  12. Mahalanobis Cleaning
 
 If a streamline passes all steps for a bundle, it is included in that bundle.
-If a streamline passess all steps for multiple bundles, one of three things
-happens. By default, the probability maps are used as the tie breaker. If
-`roi_dist_tie_break` is set in `segmentation_params`, then distance to the ROI
-is used instead. If both of these still results in a tie (or no probability
-map is provided for a given bundle), then the tie goes to whichever bundle is
-first in the bundle dictionary.
+If a streamline passes all steps for multiple bundles, a warning is issued
+and the tie goes to whichever bundle is first in the bundle dictionary.
 
 If, for debugging purposes, you want to save out the streamlines
diff --git a/examples/howto_examples/plot_recobundles.py b/examples/howto_examples/plot_recobundles.py
index 411ba1654..bfd9c042f 100644
--- a/examples/howto_examples/plot_recobundles.py
+++ b/examples/howto_examples/plot_recobundles.py
@@ -16,6 +16,7 @@
 import os.path as op
 import AFQ.data.fetch as afd
 from AFQ.api.group import GroupAFQ
+import AFQ.api.bundle_dict as abd
 
 afd.organize_stanford_data(clear_previous_afq="track")
 
@@ -30,7 +31,7 @@
 # We also refer to bundle recognition as the "segmentation" of the tractogram.
 # Parameters of this process are set through a dictionary input to the
 # `segmentation_params` argument of the GroupAFQ object. In this case, we
-# set the `seg_algo` to "reco", which tells pyAFQ to use the RecoBundles
+# use `abd.reco_bd(16)`, which tells pyAFQ to use the RecoBundles
 # algorithm for bundle recognition.
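For a custom bundle, the per-bundle keys documented in bundledict.rst above can also be combined by hand. A sketch of such an entry (the .trk paths are placeholders, not shipped data):

    import AFQ.api.bundle_dict as abd

    custom_bundles = abd.BundleDict({
        "My Bundle": {
            "recobundles": {
                "sl": "my_bundle_model.trk",  # model streamlines
                "centroid": "my_bundle_centroid.trk"},  # orients the output
            "qb_thresh": 2.0,  # QuickBundles cleaning threshold, in mm
            "primary_axis": 1,  # bundle should mostly travel along y
            "primary_axis_percentage": 40}})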
myafq = GroupAFQ( @@ -38,7 +39,7 @@ 'recobundles'), bids_path=op.join(afd.afq_home, 'stanford_hardi'), # Set the algorithm to use RecoBundles for bundle recognition: - segmentation_params={'seg_algo': 'reco'}, + bundle_info=abd.reco_bd(16), preproc_pipeline='vistasoft', tracking_params=tracking_params, viz_backend_spec='plotly_no_gif') From d1e0ccfef832a4738adfd9c74e701287c8abdaac Mon Sep 17 00:00:00 2001 From: 36000 Date: Mon, 29 Apr 2024 14:59:19 -0700 Subject: [PATCH 02/11] BFs --- AFQ/api/bundle_dict.py | 120 ++++++++++++++-------------------- AFQ/api/group.py | 67 ++++++++++--------- AFQ/bundle_rec/cleaning.py | 2 +- AFQ/bundle_rec/criteria.py | 3 +- AFQ/bundle_rec/recognize.py | 15 +++++ AFQ/data/fetch.py | 11 ---- AFQ/tasks/segmentation.py | 60 ++++++++--------- AFQ/tests/test_bundle_dict.py | 3 - AFQ/utils/streamlines.py | 9 ++- 9 files changed, 134 insertions(+), 156 deletions(-) diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index e4b8b3d09..e54679ac0 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -494,7 +494,6 @@ def baby_bd(): "primary_axis_percentage": 40, "cross_midline": False, "mahal": {"distance_threshold": 4}}}, - seg_algo="afq", resample_to=afd.read_pediatric_templates()[ 'UNCNeo-withCerebellum-for-babyAFQ']) @@ -567,7 +566,7 @@ def reco_bd(n_bundles): Selects between 16 or 80 bundle atlas """ templates = afd.read_hcp_atlas(n_bundles, as_file=True) - return BundleDict(templates, seg_algo="reco") + return BundleDict(templates) def cerebellar_bd(): @@ -696,12 +695,6 @@ class BundleDict(MutableMapping): custom bundles. See `Defining Custom Bundle Dictionaries` in the `usage` section of pyAFQ's documentation for details. - seg_algo: One of {"afq", "reco"} - The bundle segmentation algorithm to use. - "afq" : Use waypoint ROIs + probability maps, as described - in [Yeatman2012]_ - "reco" : Use Recobundles [Garyfallidis2017]_ - resample_to : Nifti1Image or bool, optional If there are bundles in bundle_info with the 'space' attribute set to 'template', or with no 'space' attribute, @@ -762,7 +755,6 @@ class BundleDict(MutableMapping): def __init__(self, bundle_info, - seg_algo="afq", resample_to=None, resample_subject_to=False, keep_in_memory=False): @@ -770,7 +762,6 @@ def __init__(self, raise TypeError(( f"bundle_info must be a dict," f" currently a {type(bundle_info)}")) - self.seg_algo = seg_algo.lower() if resample_to is None: resample_to = afd.read_mni_template() self.resample_to = resample_to @@ -784,31 +775,30 @@ def __init__(self, self.__setitem__(key, item) self.logger = logging.getLogger('AFQ') - if self.seg_algo == "afq": - if "Forceps Major" in self.bundle_names\ - and "Callosum Occipital" in self.bundle_names: - self.logger.info(( - "Forceps Major and Callosum Occipital bundles" - " are co-located, and AFQ" - " assigns each streamline to only one bundle." - " Only Callosum Occipital will be used.")) - self.bundle_names.remove("Forceps Major") - if "Forceps Minor" in self.bundle_names\ - and "Callosum Orbital" in self.bundle_names: - self.logger.info(( - "Forceps Minor and Callosum Orbital bundles" - " are co-located, and AFQ" - " assigns each streamline to only one bundle." - " Only Callosum Orbital will be used.")) - self.bundle_names.remove("Forceps Minor") - if "Forceps Minor" in self.bundle_names\ - and "Callosum Anterior Frontal" in self.bundle_names: - self.logger.info(( - "Forceps Minor and Callosum Anterior Frontal bundles" - " are co-located, and AFQ" - " assigns each streamline to only one bundle." 
- " Only Callosum Anterior Frontal will be used.")) - self.bundle_names.remove("Forceps Minor") + if "Forceps Major" in self.bundle_names\ + and "Callosum Occipital" in self.bundle_names: + self.logger.info(( + "Forceps Major and Callosum Occipital bundles" + " are co-located, and AFQ" + " assigns each streamline to only one bundle." + " Only Callosum Occipital will be used.")) + self.bundle_names.remove("Forceps Major") + if "Forceps Minor" in self.bundle_names\ + and "Callosum Orbital" in self.bundle_names: + self.logger.info(( + "Forceps Minor and Callosum Orbital bundles" + " are co-located, and AFQ" + " assigns each streamline to only one bundle." + " Only Callosum Orbital will be used.")) + self.bundle_names.remove("Forceps Minor") + if "Forceps Minor" in self.bundle_names\ + and "Callosum Anterior Frontal" in self.bundle_names: + self.logger.info(( + "Forceps Minor and Callosum Anterior Frontal bundles" + " are co-located, and AFQ" + " assigns each streamline to only one bundle." + " Only Callosum Anterior Frontal will be used.")) + self.bundle_names.remove("Forceps Minor") def update_max_includes(self, new_max): if new_max > self.max_includes: @@ -832,11 +822,11 @@ def _cond_load(self, roi_or_sl, resample_to): Load ROI or streamline if not already loaded """ if isinstance(roi_or_sl, str): - if self.seg_algo == "afq": + if ".nii" in roi_or_sl: return afd.read_resample_roi( roi_or_sl, resample_to=resample_to) - elif self.seg_algo.startswith("reco"): + else: return load_tractogram( roi_or_sl, 'same', @@ -860,7 +850,6 @@ def __getitem__(self, key): return self.__class__( new_bd, - seg_algo=self.seg_algo, resample_to=self.resample_to, resample_subject_to=self.resample_subject_to, keep_in_memory=self.keep_in_memory) @@ -925,7 +914,6 @@ def copy(self): """ return self.__class__( self._dict.copy(), - seg_algo=self.seg_algo, resample_to=self.resample_to, resample_subject_to=self.resample_subject_to, keep_in_memory=self.keep_in_memory) @@ -958,27 +946,25 @@ def apply_to_rois(self, b_name, func, *args, the roi type and values are the transformed ROIs. 
""" return_vals = {} - if self.seg_algo == "afq": - for roi_type in ["include", "exclude", "start", "end", "prob_map"]: - if roi_type in self._dict[b_name]: - if roi_type in ["start", "end", "prob_map"]: - return_vals[roi_type] = func( - self._dict[b_name][roi_type], *args, **kwargs) - else: - changed_rois = [] - for _roi in self._dict[b_name][roi_type]: - changed_rois.append(func( - _roi, *args, **kwargs)) - return_vals[roi_type] = changed_rois - elif self.seg_algo.startswith("reco"): - if b_name == "whole_brain": - return_vals = func( - self._dict[b_name], *args, **kwargs) - else: - for sl_type in ["sl", "centroid"]: - return_vals[sl_type] = func( - self._dict[b_name][sl_type], - *args, **kwargs) + for roi_type in [ + "include", "exclude", + "start", "end", "prob_map"]: + if roi_type in self._dict[b_name]: + if roi_type in ["start", "end", "prob_map"]: + return_vals[roi_type] = func( + self._dict[b_name][roi_type], *args, **kwargs) + else: + changed_rois = [] + for _roi in self._dict[b_name][roi_type]: + changed_rois.append(func( + _roi, *args, **kwargs)) + return_vals[roi_type] = changed_rois + if "recobundles" in self._dict[b_name]: + return_vals["recobundles"] = {} + for sl_type in ["sl", "centroid"]: + return_vals["recobundles"][sl_type] = func( + self._dict[b_name]["recobundles"][sl_type], + *args, **kwargs) if not dry_run: for roi_type, roi in return_vals.items(): self._dict[b_name][roi_type] = roi @@ -995,13 +981,10 @@ def _cond_load_bundle(self, b_name, dry_run=False): b_name : str Name of the bundle to be resampled. """ - if self.seg_algo == "afq": - if self.is_bundle_in_template(b_name): - resample_to = self.resample_to - else: - resample_to = self.resample_subject_to + if self.is_bundle_in_template(b_name): + resample_to = self.resample_to else: - resample_to = None + resample_to = self.resample_subject_to return self.apply_to_rois( b_name, self._cond_load, @@ -1087,10 +1070,6 @@ def transform_rois(self, bundle_name, mapping, new_affine, return transformed_rois def __add__(self, other): - if self.seg_algo != other.seg_algo: - raise ValueError(( - "Adding BundleDicts where seg_algo do not match." 
- f"seg_algo's are {self.seg_algo} and {other.seg_algo}")) for resample in ["resample_to", "resample_subject_to"]: if not getattr(self, resample)\ or not getattr(other, resample)\ @@ -1120,7 +1099,6 @@ def __add__(self, other): f"{getattr(other, resample).header['dim']}")) return self.__class__( {**self._dict, **other._dict}, - self.seg_algo, self.resample_to, self.resample_subject_to, self.keep_in_memory) diff --git a/AFQ/api/group.py b/AFQ/api/group.py index 76a297a60..bbcc046f1 100644 --- a/AFQ/api/group.py +++ b/AFQ/api/group.py @@ -441,40 +441,39 @@ def load_next_subject(): sls_dict = {} load_next_subject() # load first subject for b in bundle_dict.keys(): - if b != "whole_brain": - for i in range(len(self.valid_sub_list)): - seg_sft, mapping = subses_info[i] - idx = seg_sft.bundle_idxs[b] - # use the first subses that works - # otherwise try each successive subses - if len(idx) == 0: - # break if we run out of subses - if i + 1 >= len(self.valid_sub_list): - break - # load subses if not already loaded - if i + 1 >= len(subses_info): - load_next_subject() - continue - if len(idx) > 100: - idx = np.random.choice( - idx, size=100, replace=False) - these_sls = seg_sft.sft.streamlines[idx] - these_sls = dps.set_number_of_points(these_sls, 100) - tg = StatefulTractogram( - these_sls, - seg_sft.sft, - Space.RASMM) - delta = dts.values_from_volume( - mapping.forward, - tg.streamlines, np.eye(4)) - moved_sl = dts.Streamlines( - [d + s for d, s in zip(delta, tg.streamlines)]) - moved_sl = np.asarray(moved_sl) - median_sl = np.median(moved_sl, axis=0) - sls_dict[b] = {"coreFiber": median_sl.tolist()} - for ii, sl_idx in enumerate(idx): - sls_dict[b][str(sl_idx)] = moved_sl[ii].tolist() - break + for i in range(len(self.valid_sub_list)): + seg_sft, mapping = subses_info[i] + idx = seg_sft.bundle_idxs[b] + # use the first subses that works + # otherwise try each successive subses + if len(idx) == 0: + # break if we run out of subses + if i + 1 >= len(self.valid_sub_list): + break + # load subses if not already loaded + if i + 1 >= len(subses_info): + load_next_subject() + continue + if len(idx) > 100: + idx = np.random.choice( + idx, size=100, replace=False) + these_sls = seg_sft.sft.streamlines[idx] + these_sls = dps.set_number_of_points(these_sls, 100) + tg = StatefulTractogram( + these_sls, + seg_sft.sft, + Space.RASMM) + delta = dts.values_from_volume( + mapping.forward, + tg.streamlines, np.eye(4)) + moved_sl = dts.Streamlines( + [d + s for d, s in zip(delta, tg.streamlines)]) + moved_sl = np.asarray(moved_sl) + median_sl = np.median(moved_sl, axis=0) + sls_dict[b] = {"coreFiber": median_sl.tolist()} + for ii, sl_idx in enumerate(idx): + sls_dict[b][str(sl_idx)] = moved_sl[ii].tolist() + break with open(sls_json_fname, 'w') as fp: json.dump(sls_dict, fp) diff --git a/AFQ/bundle_rec/cleaning.py b/AFQ/bundle_rec/cleaning.py index cbc02948f..972fce70c 100644 --- a/AFQ/bundle_rec/cleaning.py +++ b/AFQ/bundle_rec/cleaning.py @@ -4,10 +4,10 @@ from scipy.stats import zscore import dipy.tracking.streamline as dts -from dipy.stats.analysis import gaussian_weights from dipy.io.stateful_tractogram import StatefulTractogram, Space import AFQ.bundle_rec.utils as abu +from AFQ._fixes import gaussian_weights logger = logging.getLogger('AFQ') diff --git a/AFQ/bundle_rec/criteria.py b/AFQ/bundle_rec/criteria.py index 15b544ff8..508155fab 100644 --- a/AFQ/bundle_rec/criteria.py +++ b/AFQ/bundle_rec/criteria.py @@ -137,7 +137,8 @@ def include(b_sls, bundle_def, preproc_imap, max_includes, 
include_roi_tols.append(( inc_tol / preproc_imap["vox_dim"] + preproc_imap["tol"])**2) else: - include_roi_tols = [preproc_imap["tol"]**2] * len(bundle_def["include"]) + include_roi_tols = [preproc_imap["tol"]**2] * len( + bundle_def["include"]) include_rois = [] for include_roi in bundle_def["include"]: diff --git a/AFQ/bundle_rec/recognize.py b/AFQ/bundle_rec/recognize.py index 9e1771343..96d6da104 100644 --- a/AFQ/bundle_rec/recognize.py +++ b/AFQ/bundle_rec/recognize.py @@ -43,6 +43,16 @@ def recognize( Parameters ---------- + tg : str, StatefulTractogram + Tractogram to segment. + img : str, nib.Nifti1Image + Image for reference. + mapping : MappingDefinition + Mapping from subject to template. + bundle_dict : dict or AFQ.api.BundleDict + Dictionary of bundles to segment. + reg_template : str, nib.Nifti1Image + Template image for registration. nb_points : int, boolean Resample streamlines to nb_points number of points. If False, no resampling is done. Default: False @@ -61,8 +71,10 @@ def recognize( Default: {"engine": "serial"} rb_recognize_params : dict RecoBundles parameters for the recognize function. + Default: dict(model_clust_thr=1.25, reduction_thr=25, pruning_thr=12) refine_reco : bool Whether to refine the RecoBundles segmentation. + Default: False prob_threshold : float. Using AFQ Algorithm. Initial cleaning of fiber groups is done using probability maps @@ -81,6 +93,7 @@ def recognize( modification to your bundle_dict: bundle_dict["Right Arcuate"]["inc_addtol"] = [3, 3] Additional tolerances can also be negative. + Default: None. rng : RandomState or int If None, creates RandomState. If int, creates RandomState with seed rng. @@ -89,12 +102,14 @@ def recognize( return_idx : bool Whether to return the indices in the original streamlines as part of the output of segmentation. + Default: False. filter_by_endpoints: bool Whether to filter the bundles based on their endpoints. Default: True. dist_to_atlas : float If filter_by_endpoints is True, this is the required distance from the endpoints to the atlas ROIs. + Default: 4 save_intermediates : str, optional The full path to a folder into which intermediate products are saved. Default: None, means no saving of intermediates. 
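For orientation, the parameters documented above correspond to a call of roughly the following shape. This is a hypothetical sketch, not part of the patch series: the file paths are placeholders, and the `mapping`, `bundle_dict`, and `reg_template` objects are assumed to have been prepared elsewhere (for example, through the pyAFQ API objects).

import nibabel as nib
from dipy.io.streamline import load_tractogram

from AFQ.bundle_rec.recognize import recognize

# Hypothetical inputs: any reference image and whole-brain tractogram.
img = nib.load("dwi.nii.gz")
tg = load_tractogram("whole_brain_tractography.trk", img)

# mapping, bundle_dict, and reg_template are assumed precomputed:
# a subject-to-template mapping, the bundle definitions, and the
# registration template image, respectively.
fiber_groups, meta = recognize(
    tg, img, mapping, bundle_dict, reg_template,
    prob_threshold=0,           # probability-map cutoff (assumed value)
    filter_by_endpoints=True,   # documented default
    dist_to_atlas=4,            # documented default, in mm
    return_idx=False)           # documented default

The returned `fiber_groups` maps bundle names to StatefulTractograms (or to dicts with 'sl' and 'idx' entries when return_idx is True), and `meta` records the serializable parts of each bundle definition.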
diff --git a/AFQ/data/fetch.py b/AFQ/data/fetch.py index 15d30c91c..b87771d86 100644 --- a/AFQ/data/fetch.py +++ b/AFQ/data/fetch.py @@ -1259,17 +1259,6 @@ def read_hcp_atlas(n_bundles=16, as_file=False): _, folder = fetch_hcp_atlas_80_bundles() atlas_folder = "Atlas_80_Bundles" - whole_brain = op.join( - folder, - atlas_folder, - 'whole_brain', - 'whole_brain_MNI.trk') - if not as_file: - whole_brain = load_tractogram( - whole_brain, - 'same', bbox_valid_check=False).streamlines - - bundle_dict['whole_brain'] = whole_brain bundle_files = glob( op.join( folder, diff --git a/AFQ/tasks/segmentation.py b/AFQ/tasks/segmentation.py index 9b6147cff..6b979dd57 100644 --- a/AFQ/tasks/segmentation.py +++ b/AFQ/tasks/segmentation.py @@ -15,6 +15,7 @@ import AFQ.utils.streamlines as aus from AFQ.tasks.utils import get_default_args import AFQ.utils.volume as auv +from AFQ._fixes import gaussian_weights try: from trx.io import load as load_trx @@ -25,8 +26,8 @@ has_trx = False from dipy.io.streamline import load_tractogram, save_tractogram -from dipy.io.stateful_tractogram import Space, StatefulTractogram -from dipy.stats.analysis import afq_profile, gaussian_weights +from dipy.io.stateful_tractogram import Space +from dipy.stats.analysis import afq_profile from dipy.tracking.streamline import set_number_of_points, values_from_volume @@ -127,35 +128,34 @@ def export_bundles(base_fname, output_dir, os.makedirs(bundles_dir, exist_ok=True) seg_sft = aus.SegmentedSFT.fromfile(bundles) for bundle in seg_sft.bundle_names: - if bundle != "whole_brain": - fname = op.split( - get_fname( - base_fname, - f'_desc-{str_to_desc(bundle)}' - f'_tractography{extension}', - tracking_params=tracking_params, - segmentation_params=segmentation_params)) - fname = op.join(bundles_dir, fname[1]) - bundle_sft = seg_sft.get_bundle(bundle) - if len(bundle_sft) > 0: - logger.info(f"Saving {fname}") - if is_trx: - seg_sft.sft.dtype_dict = { - 'positions': np.float16, - 'offsets': np.uint32} - trxfile = TrxFile.from_sft(bundle_sft) - save_trx(trxfile, fname) - else: - save_tractogram( - bundle_sft, fname, - bbox_valid_check=False) + fname = op.split( + get_fname( + base_fname, + f'_desc-{str_to_desc(bundle)}' + f'_tractography{extension}', + tracking_params=tracking_params, + segmentation_params=segmentation_params)) + fname = op.join(bundles_dir, fname[1]) + bundle_sft = seg_sft.get_bundle(bundle) + if len(bundle_sft) > 0: + logger.info(f"Saving {fname}") + if is_trx: + seg_sft.sft.dtype_dict = { + 'positions': np.float16, + 'offsets': np.uint32} + trxfile = TrxFile.from_sft(bundle_sft) + save_trx(trxfile, fname) else: - logger.info(f"No bundle to save for {bundle}") - meta = dict( - source=bundles, - params=seg_sft.get_bundle_param_info(bundle)) - meta_fname = drop_extension(fname) + '.json' - write_json(meta_fname, meta) + save_tractogram( + bundle_sft, fname, + bbox_valid_check=False) + else: + logger.info(f"No bundle to save for {bundle}") + meta = dict( + source=bundles, + params=seg_sft.get_bundle_param_info(bundle)) + meta_fname = drop_extension(fname) + '.json' + write_json(meta_fname, meta) return bundles_dir diff --git a/AFQ/tests/test_bundle_dict.py b/AFQ/tests/test_bundle_dict.py index 42642e723..28e85ce0c 100644 --- a/AFQ/tests/test_bundle_dict.py +++ b/AFQ/tests/test_bundle_dict.py @@ -56,6 +56,3 @@ def test_BundleDict(): afq_bundles = abd.reco_bd(80)["VOF_L", "VOF_R"] assert len(afq_bundles) == 2 - - afq_bundles = abd.reco_bd(80)[["whole_brain"]] - assert len(afq_bundles) == 1 diff --git a/AFQ/utils/streamlines.py 
b/AFQ/utils/streamlines.py index 0cb910b88..cee59135d 100644 --- a/AFQ/utils/streamlines.py +++ b/AFQ/utils/streamlines.py @@ -97,11 +97,10 @@ def fromfile(cls, trk_or_trx_file, reference="same", sidecar_file=None): bundles = {} if "bundle_ids" in sidecar_info: for b_name, b_id in sidecar_info["bundle_ids"].items(): - if not b_name == "whole_brain": - idx = np.where( - sft.data_per_streamline['bundle'] == b_id)[0] - bundles[b_name] = StatefulTractogram( - sft.streamlines[idx], reference, Space.RASMM) + idx = np.where( + sft.data_per_streamline['bundle'] == b_id)[0] + bundles[b_name] = StatefulTractogram( + sft.streamlines[idx], reference, Space.RASMM) else: bundles["whole_brain"] = sft From e164341c3ef5629672d9539c035ef43e917307af Mon Sep 17 00:00:00 2001 From: 36000 Date: Mon, 29 Apr 2024 16:17:55 -0700 Subject: [PATCH 03/11] BF --- AFQ/_fixes.py | 18 +++++++++------- AFQ/api/bundle_dict.py | 44 ++++++++++++++++++++++++++------------ AFQ/bundle_rec/criteria.py | 3 ++- 3 files changed, 42 insertions(+), 23 deletions(-) diff --git a/AFQ/_fixes.py b/AFQ/_fixes.py index 038581fad..d4383431d 100644 --- a/AFQ/_fixes.py +++ b/AFQ/_fixes.py @@ -195,19 +195,21 @@ def gaussian_weights(bundle, n_points=100, return_mahalnobis=False, else: resample = True if resample: - sls = set_number_of_points(bundle, n_points) + sls = np.asarray(set_number_of_points(bundle, n_points)) else: sls = bundle - # If there's only one fiber here, it gets the entire weighting: - if len(bundle) == 1: + n_sls, n_nodes, _ = sls.shape + + # Only do this with sufficient streamlines: + if n_sls < 20: + weights = np.ones((n_sls, n_nodes)) if return_mahalnobis: - return np.array([np.nan]) + return np.full((n_sls, n_nodes), np.nan) else: - return np.array([1]) - - n_sls, n_nodes, n_dim = sls.shape - weights = np.zeros((n_sls, n_nodes)) + return weights / np.sum(weights, 0) + else: + weights = np.zeros((n_sls, n_nodes)) diff = stat(sls, axis=0) - sls for i in range(n_nodes): # This should come back as a 3D covariance matrix with the spatial diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index e54679ac0..794de8a81 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -919,7 +919,7 @@ def copy(self): keep_in_memory=self.keep_in_memory) def apply_to_rois(self, b_name, func, *args, - dry_run=False, + dry_run=False, apply_to_recobundles=False, **kwargs): """ Applies some transformation to all ROIs (include, exclude, end, start) @@ -935,6 +935,11 @@ def apply_to_rois(self, b_name, func, *args, dry_run : bool Whether to actually apply changes returned by `func` to the ROIs. If has_return is False, dry_run is not used. + apply_to_recobundles : bool, optional + Whether to apply the transformation to recobundles + TRKs as well. 
+ Default: False + *args : Additional arguments for func **kwargs @@ -959,7 +964,7 @@ def apply_to_rois(self, b_name, func, *args, changed_rois.append(func( _roi, *args, **kwargs)) return_vals[roi_type] = changed_rois - if "recobundles" in self._dict[b_name]: + if apply_to_recobundles and "recobundles" in self._dict[b_name]: return_vals["recobundles"] = {} for sl_type in ["sl", "centroid"]: return_vals["recobundles"][sl_type] = func( @@ -989,23 +994,28 @@ def _cond_load_bundle(self, b_name, dry_run=False): b_name, self._cond_load, resample_to, - dry_run=dry_run) + dry_run=dry_run, + apply_to_recobundles=True) def is_bundle_in_template(self, bundle_name): return "space" not in self._dict[bundle_name]\ or self._dict[bundle_name]["space"] == "template" - def _roi_transform_helper(self, roi, mapping, new_affine, bundle_name): - roi = afd.read_resample_roi(roi, self.resample_to) - warped_img = auv.transform_inverse_roi( - roi.get_fdata(), - mapping, - bundle_name=bundle_name) - warped_img = nib.Nifti1Image(warped_img, new_affine) - return warped_img + def _roi_transform_helper(self, roi_or_sl, mapping, + new_affine, bundle_name): + roi_or_sl = self._cond_load(roi_or_sl, self.resample_to) + if isinstance(roi_or_sl, nib.Nifti1Image): + warped_img = auv.transform_inverse_roi( + roi_or_sl.get_fdata(), + mapping, + bundle_name=bundle_name) + warped_img = nib.Nifti1Image(warped_img, new_affine) + return warped_img + else: + return roi_or_sl def transform_rois(self, bundle_name, mapping, new_affine, - base_fname=None): + base_fname=None, apply_to_recobundles=False): """ Get the bundle definition with transformed ROIs for a given bundle into a @@ -1026,6 +1036,10 @@ def transform_rois(self, bundle_name, mapping, new_affine, Base file path to save ROIs too. Additional BIDS descriptors will be added to this file path. If None, do not save the ROIs. + apply_to_recobundles : bool, optional + Whether to apply the transformation to recobundles + TRKs as well. 
+            Default: False
 
         Returns
         -------
@@ -1041,13 +1055,15 @@ def transform_rois(self, bundle_name, mapping, new_affine,
                 mapping,
                 new_affine,
                 bundle_name,
-                dry_run=True)
+                dry_run=True,
+                apply_to_recobundles=apply_to_recobundles)
         else:
             transformed_rois = self.apply_to_rois(
                 bundle_name,
                 self._cond_load,
                 self.resample_subject_to,
-                dry_run=True)
+                dry_run=True,
+                apply_to_recobundles=apply_to_recobundles)
 
         if base_fname is not None:
             fnames = []
diff --git a/AFQ/bundle_rec/criteria.py b/AFQ/bundle_rec/criteria.py
index 508155fab..3598fca34 100644
--- a/AFQ/bundle_rec/criteria.py
+++ b/AFQ/bundle_rec/criteria.py
@@ -333,7 +333,8 @@ def run_bundle_rec_plan(
         bundle_def.update(bundle_dict.transform_rois(
             bundle_name,
             mapping,
-            img.affine))
+            img.affine,
+            apply_to_recobundles=True))

From 6eaf61aa9958e3b78391be85cf4d3c1ff87d9a6d Mon Sep 17 00:00:00 2001
From: 36000
Date: Tue, 30 Apr 2024 07:10:59 -0700
Subject: [PATCH 04/11] fix gaussian weights for low n

---
 AFQ/_fixes.py              | 28 +++++++++++--------------
 AFQ/bundle_rec/cleaning.py |  3 ++-
 AFQ/tests/test_fixes.py    | 42 +++++++++++++++++++++++---------------
 3 files changed, 40 insertions(+), 33 deletions(-)

diff --git a/AFQ/_fixes.py b/AFQ/_fixes.py
index d4383431d..cf8c03e88 100644
--- a/AFQ/_fixes.py
+++ b/AFQ/_fixes.py
@@ -153,7 +153,7 @@ def tensor_odf(evals, evecs, sphere, num_batches=100):
 
 
 def gaussian_weights(bundle, n_points=100, return_mahalnobis=False,
-                     stat=np.mean):
+                     stat=np.mean, resample=True):
     """
     Calculate weights for each streamline/node in a bundle, based on a
     Mahalanobis distance from the core the bundle, at that node (mean, per
@@ -173,7 +173,10 @@ def gaussian_weights(bundle, n_points=100, return_mahalnobis=False,
         The statistic used to calculate the central tendency of streamlines
         in each node. Can be one of {`np.mean`, `np.median`} or other
         functions that have similar API. Default: `np.mean`
-
+    resample : bool, optional
+        Whether it's necessary to resample the streamlines to the same number
+        of points. Only set to False if they are already resampled.
+        Default: True.
     Returns
     -------
     w : array of shape (n_streamlines, n_points)
         Weights for each node in each streamline, calculated as its relative
         inverse of the Mahalanobis distance, relative to the distribution of
         coordinates at that node position across streamlines.
""" - - # Resample to same length for each streamline - # if necessary - resample = False - if isinstance(bundle, np.ndarray): - if len(bundle.shape) > 2: - if bundle.shape[1] != n_points: - sls = bundle.tolist() - sls = [np.asarray(item) for item in sls] - resample = True - else: - resample = True if resample: + if isinstance(bundle, np.ndarray): + bundle = bundle.tolist() + if isinstance(bundle, list): + bundle = [np.asarray(item) for item in bundle] sls = np.asarray(set_number_of_points(bundle, n_points)) else: sls = bundle n_sls, n_nodes, _ = sls.shape - # Only do this with sufficient streamlines: - if n_sls < 20: + if n_sls < 15: # Cov^-1 unstable under this amount weights = np.ones((n_sls, n_nodes)) if return_mahalnobis: return np.full((n_sls, n_nodes), np.nan) @@ -216,6 +210,8 @@ def gaussian_weights(bundle, n_points=100, return_mahalnobis=False, # variance covariance of this node across the different streamlines, # reorganized as an upper diagonal matrix for expected Mahalanobis cov = np.cov(sls[:, i, :].T, ddof=0) + while np.any(np.linalg.eigvals(cov) < 0): + cov += np.eye(cov.shape[0]) * 1e-12 # calculate Mahalanobis for node in every fiber if np.any(cov > 0): diff --git a/AFQ/bundle_rec/cleaning.py b/AFQ/bundle_rec/cleaning.py index 972fce70c..498134293 100644 --- a/AFQ/bundle_rec/cleaning.py +++ b/AFQ/bundle_rec/cleaning.py @@ -136,7 +136,8 @@ def clean_bundle(tg, n_points=100, clean_rounds=5, distance_threshold=3, # This calculates the Mahalanobis for each streamline/node: m_dist = gaussian_weights( fgarray, return_mahalnobis=True, - n_points=n_points, stat=stat) + n_points=n_points, stat=stat, + resample=False) logger.debug(f"Shape of fgarray: {np.asarray(fgarray).shape}") logger.debug(f"Shape of m_dist: {m_dist.shape}") logger.debug(f"Maximum m_dist: {np.max(m_dist)}") diff --git a/AFQ/tests/test_fixes.py b/AFQ/tests/test_fixes.py index ace7c9f28..8aaf3679e 100644 --- a/AFQ/tests/test_fixes.py +++ b/AFQ/tests/test_fixes.py @@ -10,10 +10,8 @@ import dipy.core.gradients as dpg from dipy.data import default_sphere from dipy.reconst.gqi import GeneralizedQSamplingModel -from dipy.tracking.streamline import set_number_of_points from AFQ._fixes import gaussian_weights as gaussian_weights_fast -from dipy.stats.analysis import gaussian_weights from AFQ.utils.testing import make_dki_data from AFQ._fixes import gwi_odf @@ -41,22 +39,34 @@ def test_GQI_fix(): def test_mahal_fix(): sls = [ - [ - [8.0, 53, 39], [8, 50, 39], [8, 45, 39], - [30, 41, 61], [28, 61, 38]], - [ - [8, 53, 39], [8, 50, 39], [8, 45, 39], - [30, 41, 62], [20, 44, 34]], - [ - [8, 53, 39], [8, 50, 39], [8, 45, 39], - [50, 67, 88], [10, 10, 20]] - ] + [[8.0, 53, 39], [8, 50, 39], [8, 45, 39], [30, 41, 61], [28, 61, 38]], + [[8, 53, 39], [8, 50, 39], [8, 45, 39], [30, 41, 62], [20, 44, 34]], + [[8, 53, 39], [8, 50, 39], [8, 45, 39], [50, 67, 88], [10, 10, 20]], + [[8, 53, 39], [8, 50, 39], [8, 45, 39], [35, 43, 65], [25, 55, 35]], + [[8, 53, 39], [8, 50, 39], [8, 45, 39], [40, 50, 70], [15, 15, 25]], + [[8, 53, 39], [8, 50, 39], [8, 45, 39], [45, 54, 75], [12, 22, 32]], + [[8, 53, 39], [8, 50, 39], [8, 45, 39], [32, 48, 68], [28, 58, 40]], + [[8, 53, 39], [8, 50, 39], [8, 45, 39], [38, 52, 72], [18, 38, 28]], + [[8, 53, 39], [8, 50, 39], [8, 45, 39], [34, 44, 64], [21, 41, 31]], + [[8, 53, 39], [8, 50, 39], [8, 45, 39], [36, 46, 66], [23, 53, 33]], + [[8, 53, 39], [8, 50, 39], [8, 45, 39], [37, 47, 67], [24, 54, 34]], + [[8, 53, 39], [8, 50, 39], [8, 45, 39], [39, 49, 69], [19, 39, 29]], + [[8, 53, 39], [8, 50, 39], 
[8, 45, 39], [33, 53, 73], [22, 42, 32]], + [[8, 53, 39], [8, 50, 39], [8, 45, 39], [31, 51, 71], [26, 56, 36]], + [[8, 53, 39], [8, 50, 39], [8, 45, 39], [29, 59, 79], [27, 57, 37]], + [[8, 53, 39], [8, 50, 39], [8, 45, 39], [28, 58, 78], [17, 47, 27]], + [[8, 53, 39], [8, 50, 39], [8, 45, 39], [27, 57, 77], [16, 36, 26]], + [[8, 53, 39], [8, 50, 39], [8, 45, 39], [26, 56, 76], [14, 24, 34]], + [[8, 53, 39], [8, 50, 39], [8, 45, 39], [25, 55, 75], [13, 23, 33]], + [[8, 53, 39], [8, 50, 39], [8, 45, 39], [24, 54, 74], [11, 21, 31]] + ] sls_array = np.asarray(sls).astype(float) results = np.asarray([ - [0. , 0. , 0. , 1.185854, 2.14735], - [0. , 0. , 0. , 1.185854, 1.556795], - [0. , 0. , 0. , 1.274755, 2.23296]]) + [0. , 0. , 0. , 1.718654, 1.550252], + [0. , 0. , 0. , 2.202227, 0.7881 ], + [0. , 0. , 0. , 3.415999, 2.689814]]) npt.assert_array_almost_equal( gaussian_weights_fast( sls_array, n_points=5, - return_mahalnobis=True, stat=np.mean), results) + return_mahalnobis=True, + stat=np.mean, resample=False)[:3], results) From 37915b55a8fe6e5260ba5bcfd7a6e618e140934e Mon Sep 17 00:00:00 2001 From: 36000 Date: Tue, 30 Apr 2024 08:27:01 -0700 Subject: [PATCH 05/11] bf --- AFQ/api/bundle_dict.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index 794de8a81..99f0cbea3 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -683,7 +683,6 @@ def __setitem__(self, key, value): "in the BundleDict.")) -# TODO: refactor the recobundles parts of this class BundleDict(MutableMapping): """ Create a bundle dictionary, needed for the segmentation. @@ -831,6 +830,10 @@ def _cond_load(self, roi_or_sl, resample_to): roi_or_sl, 'same', bbox_valid_check=False).streamlines + elif isinstance(roi_or_sl, nib.Nifti1Image): + return afd.read_resample_roi( + roi_or_sl, + resample_to=resample_to) else: return roi_or_sl From 478a1dbca4ca652cf7b9419c8242362f8b8a2d14 Mon Sep 17 00:00:00 2001 From: 36000 Date: Mon, 6 May 2024 10:51:12 -0700 Subject: [PATCH 06/11] small readability changes --- AFQ/bundle_rec/preprocess.py | 4 +- AFQ/bundle_rec/recognize.py | 101 ++++++++++++++++++----------------- 2 files changed, 55 insertions(+), 50 deletions(-) diff --git a/AFQ/bundle_rec/preprocess.py b/AFQ/bundle_rec/preprocess.py index 73628eddd..9beac5b55 100644 --- a/AFQ/bundle_rec/preprocess.py +++ b/AFQ/bundle_rec/preprocess.py @@ -12,7 +12,7 @@ @pimms.calc("tol", "dist_to_atlas", "vox_dim") -def incexc_tol(img, dist_to_waypoint, input_dist_to_atlas): +def tolerance_mm_to_vox(img, dist_to_waypoint, input_dist_to_atlas): # We need to calculate the size of a voxel, so we can transform # from mm to voxel units: R = img.affine[0:3, 0:3] @@ -64,7 +64,7 @@ def crosses(fgarray, img): # (i.e., for a whole tractogram) go here def get_preproc_plan(img, tg, dist_to_waypoint, dist_to_atlas): preproc_plan = pimms.Plan( - incexc_tol=incexc_tol, + tolerance_mm_to_vox=tolerance_mm_to_vox, fgarray=fgarray, crosses=crosses) return preproc_plan( diff --git a/AFQ/bundle_rec/recognize.py b/AFQ/bundle_rec/recognize.py index 96d6da104..890bf305f 100644 --- a/AFQ/bundle_rec/recognize.py +++ b/AFQ/bundle_rec/recognize.py @@ -211,49 +211,6 @@ def recognize( bundle_decisions, np.ones((n_streamlines, 1))), axis=1) bundle_decisions = np.argmax(bundle_decisions, -1) - # Helper functions for formatting the results - def _return_empty(bundle_name): - """ - Helper function to return an empty dict under - some conditions. 
- """ - if return_idx: - fiber_groups[bundle_name] = {} - fiber_groups[bundle_name]['sl'] = StatefulTractogram( - [], img, Space.VOX) - fiber_groups[bundle_name]['idx'] = np.array([]) - else: - fiber_groups[bundle_name] = StatefulTractogram( - [], img, Space.VOX) - - def _add_bundle_to_fiber_group(b_name, sl, idx, to_flip): - """ - Helper function to add a bundle to a fiber group. - """ - sl = abu.flip_sls( - sl, to_flip, - in_place=False) - - sl = StatefulTractogram( - sl, - img, - Space.VOX) - - if return_idx: - fiber_groups[b_name] = {} - fiber_groups[b_name]['sl'] = sl - fiber_groups[b_name]['idx'] = idx - else: - fiber_groups[b_name] = sl - - def _add_bundle_to_meta(bundle_name, b_def): - # remove keys that can never be serialized - for key in [ - 'include', 'exclude', 'prob_map', - 'start', 'end', 'curvature']: - b_def.pop(key, None) - meta[bundle_name] = b_def - # We do another round through, so that we can orient all the # streamlines within a bundle in the same orientation with respect to # the ROIs. This order is ARBITRARY but CONSISTENT (going from ROI0 @@ -269,9 +226,9 @@ def _add_bundle_to_meta(bundle_name, b_def): if "bundlesection" in bundle_dict.get_b_info(bundle): for sb_name in bundle_dict.get_b_info(bundle)[ "bundlesection"]: - _return_empty(sb_name) + _return_empty(sb_name, return_idx, fiber_groups, img) else: - _return_empty(bundle) + _return_empty(bundle, return_idx, fiber_groups, img) continue # Use a list here, because ArraySequence doesn't support item @@ -296,10 +253,58 @@ def _add_bundle_to_meta(bundle_name, b_def): sb_include_cuts, in_place=False) _add_bundle_to_fiber_group( sb_name, bundlesection_select_sl, select_idx, - to_flip) + to_flip, return_idx, fiber_groups, img) _add_bundle_to_meta(sb_name, b_def) else: _add_bundle_to_fiber_group( - bundle, select_sl, select_idx, to_flip) - _add_bundle_to_meta(bundle, b_def) + bundle, select_sl, select_idx, to_flip, + return_idx, fiber_groups, img) + _add_bundle_to_meta(bundle, b_def, meta) return fiber_groups, meta + + +# Helper functions for formatting the results +def _return_empty(bundle_name, return_idx, fiber_groups, img): + """ + Helper function to return an empty dict under + some conditions. + """ + if return_idx: + fiber_groups[bundle_name] = {} + fiber_groups[bundle_name]['sl'] = StatefulTractogram( + [], img, Space.VOX) + fiber_groups[bundle_name]['idx'] = np.array([]) + else: + fiber_groups[bundle_name] = StatefulTractogram( + [], img, Space.VOX) + + +def _add_bundle_to_fiber_group(b_name, sl, idx, to_flip, + return_idx, fiber_groups, img): + """ + Helper function to add a bundle to a fiber group. 
+ """ + sl = abu.flip_sls( + sl, to_flip, + in_place=False) + + sl = StatefulTractogram( + sl, + img, + Space.VOX) + + if return_idx: + fiber_groups[b_name] = {} + fiber_groups[b_name]['sl'] = sl + fiber_groups[b_name]['idx'] = idx + else: + fiber_groups[b_name] = sl + + +def _add_bundle_to_meta(bundle_name, b_def, meta): + # remove keys that can never be serialized + for key in [ + 'include', 'exclude', 'prob_map', + 'start', 'end', 'curvature']: + b_def.pop(key, None) + meta[bundle_name] = b_def From fb9551a74f18e77b1b6583a82dc87276d719c2fe Mon Sep 17 00:00:00 2001 From: 36000 Date: Thu, 16 May 2024 09:34:12 -0700 Subject: [PATCH 07/11] pep8 --- AFQ/tests/test_fixes.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/AFQ/tests/test_fixes.py b/AFQ/tests/test_fixes.py index 8aaf3679e..7411f1e1f 100644 --- a/AFQ/tests/test_fixes.py +++ b/AFQ/tests/test_fixes.py @@ -62,9 +62,9 @@ def test_mahal_fix(): ] sls_array = np.asarray(sls).astype(float) results = np.asarray([ - [0. , 0. , 0. , 1.718654, 1.550252], - [0. , 0. , 0. , 2.202227, 0.7881 ], - [0. , 0. , 0. , 3.415999, 2.689814]]) + [0., 0., 0., 1.718654, 1.550252], + [0., 0., 0., 2.202227, 0.7881], + [0., 0., 0., 3.415999, 2.689814]]) npt.assert_array_almost_equal( gaussian_weights_fast( sls_array, n_points=5, From b9e0b57be6c7f758df75b878d60650ef20c84a87 Mon Sep 17 00:00:00 2001 From: 36000 Date: Thu, 16 May 2024 10:09:30 -0700 Subject: [PATCH 08/11] remove pimms --- AFQ/bundle_rec/criteria.py | 68 +++++++++++--------------------------- 1 file changed, 19 insertions(+), 49 deletions(-) diff --git a/AFQ/bundle_rec/criteria.py b/AFQ/bundle_rec/criteria.py index 3598fca34..d078e246a 100644 --- a/AFQ/bundle_rec/criteria.py +++ b/AFQ/bundle_rec/criteria.py @@ -1,4 +1,3 @@ -import pimms import numpy as np import logging from time import time @@ -26,8 +25,7 @@ logger = logging.getLogger('AFQ') -@pimms.calc("prob_map") -def prob_map(b_sls, bundle_def, preproc_imap, prob_threshold): +def prob_map(b_sls, bundle_def, preproc_imap, prob_threshold, **kwargs): b_sls.initiate_selection("Prob. Map") # using entire fgarray here only because it is the first step fiber_probabilities = dts.values_from_volume( @@ -37,21 +35,17 @@ def prob_map(b_sls, bundle_def, preproc_imap, prob_threshold): b_sls.select( fiber_probabilities > prob_threshold, "Prob. 
Map") - return 1 # You have to return something -@pimms.calc("cross_midline") -def cross_midline(b_sls, bundle_def, preproc_imap): +def cross_midline(b_sls, bundle_def, preproc_imap, **kwargs): b_sls.initiate_selection("Cross Mid.") accepted = preproc_imap["crosses"][b_sls.selected_fiber_idxs] if not bundle_def["cross_midline"]: accepted = np.invert(accepted) b_sls.select(accepted, "Cross Mid.") - return 1 -@pimms.calc("start") -def start(b_sls, bundle_def, preproc_imap): +def start(b_sls, bundle_def, preproc_imap, **kwargs): accept_idx = b_sls.initiate_selection("Startpoint") abr.clean_by_endpoints( b_sls.get_selected_sls(), @@ -70,11 +64,9 @@ def start(b_sls, bundle_def, preproc_imap): accept_idx = np.logical_xor( accepted_idx_flipped, accept_idx) b_sls.select(accept_idx, "Startpoint") - return 1 -@pimms.calc("end") -def end(b_sls, bundle_def, preproc_imap): +def end(b_sls, bundle_def, preproc_imap, **kwargs): accept_idx = b_sls.initiate_selection("endpoint") abr.clean_by_endpoints( b_sls.get_selected_sls(), @@ -93,11 +85,9 @@ def end(b_sls, bundle_def, preproc_imap): accept_idx = np.logical_xor( accepted_idx_flipped, accept_idx) b_sls.select(accept_idx, "endpoint") - return 1 -@pimms.calc("length") -def length(b_sls, bundle_def, preproc_imap): +def length(b_sls, bundle_def, preproc_imap, **kwargs): accept_idx = b_sls.initiate_selection("length") min_len = bundle_def["length"].get( "min_len", 0) / preproc_imap["vox_dim"] @@ -109,11 +99,9 @@ def length(b_sls, bundle_def, preproc_imap): if sl_len >= min_len and sl_len <= max_len: accept_idx[idx] = 1 b_sls.select(accept_idx, "length") - return 1 -@pimms.calc("primary_axis") -def primary_axis(b_sls, bundle_def): +def primary_axis(b_sls, bundle_def, **kwargs): b_sls.initiate_selection("orientation") accept_idx = abc.clean_by_orientation( b_sls.get_selected_sls(), @@ -121,12 +109,10 @@ def primary_axis(b_sls, bundle_def): bundle_def.get( "primary_axis_percentage", None)) b_sls.select(accept_idx, "orientation") - return 1 -@pimms.calc("include") def include(b_sls, bundle_def, preproc_imap, max_includes, - parallel_segmentation): + parallel_segmentation, **kwargs): accept_idx = b_sls.initiate_selection("include") flip_using_include = len(bundle_def["include"]) > 1\ and not b_sls.oriented_yet @@ -203,14 +189,13 @@ def include(b_sls, bundle_def, preproc_imap, max_includes, if flip_using_include: b_sls.reorient(to_flip) b_sls.select(accept_idx, "include") - return 1 -# Filters streamlines by how well they match -# a curve in orientation and shape but not scale - -@pimms.calc("curvature") -def curvature(b_sls, bundle_def, mapping, img, save_intermediates): +def curvature(b_sls, bundle_def, mapping, img, save_intermediates, **kwargs): + ''' + Filters streamlines by how well they match + a curve in orientation and shape but not scale + ''' accept_idx = b_sls.initiate_selection("curvature") if "sft" in bundle_def["curvature"]: ref_sl = bundle_def["curvature"]["sft"] @@ -237,11 +222,9 @@ def curvature(b_sls, bundle_def, mapping, img, save_intermediates): if dist <= ref_curve_threshold: accept_idx[idx] = 1 b_sls.select(accept_idx, "curvature", cut=cut) - return 1 -@pimms.calc("exclude") -def exclude(b_sls, bundle_def, preproc_imap): +def exclude(b_sls, bundle_def, preproc_imap, **kwargs): accept_idx = b_sls.initiate_selection("exclude") if f'exc_addtol' in bundle_def: exclude_roi_tols = [] @@ -260,12 +243,10 @@ def exclude(b_sls, bundle_def, preproc_imap): sl, exclude_rois, exclude_roi_tols): accept_idx[sl_idx] = 1 b_sls.select(accept_idx, "exclude") - 
return 1 -@pimms.calc("recobundles") def recobundles(b_sls, mapping, bundle_def, reg_template, img, refine_reco, - save_intermediates, rng, rb_recognize_params): + save_intermediates, rng, rb_recognize_params, **kwargs): b_sls.initiate_selection("Recobundles") moved_sl = abu.move_streamlines( StatefulTractogram(b_sls.get_selected_sls(), img, Space.VOX), @@ -286,11 +267,9 @@ def recobundles(b_sls, mapping, bundle_def, reg_template, img, refine_reco, standard_sl) b_sls.reorient(rec_labels[oriented_idx]) b_sls.select(rec_labels, "Recobundles") - return 1 -@pimms.calc("qb_thresh") -def qb_thresh(b_sls, bundle_def, preproc_imap, clip_edges): +def qb_thresh(b_sls, bundle_def, preproc_imap, clip_edges, **kwargs): b_sls.initiate_selection("qb_thresh") cut = clip_edges or ("bundlesection" in bundle_def) qbx = QuickBundles( @@ -302,11 +281,9 @@ def qb_thresh(b_sls, bundle_def, preproc_imap, clip_edges): cleaned_idx = clusters[np.argmax( clusters.clusters_sizes())].indices b_sls.select(cleaned_idx, "qb_thresh", cut=cut) - return 1 -@pimms.calc("mahalanobis") -def mahalanobis(b_sls, bundle_def, clip_edges, cleaning_params): +def mahalanobis(b_sls, bundle_def, clip_edges, cleaning_params, **kwargs): b_sls.initiate_selection("Mahalanobis") clean_params = bundle_def.get("mahal", {}) clean_params = { @@ -318,7 +295,6 @@ def mahalanobis(b_sls, bundle_def, clip_edges, cleaning_params): b_sls.get_selected_sls(cut=cut, flip=True), **clean_params) b_sls.select(cleaned_idx, "Mahalanobis", cut=cut) - return 1 def run_bundle_rec_plan( @@ -343,12 +319,6 @@ def run_bundle_rec_plan( bundle_name, img, len(bundle_def.get("include", []))) - bundle_plan = {} - for criterion in bundle_criterion_order: - bundle_plan[criterion] = globals()[criterion] - bundle_plan["mahalanobis"] = mahalanobis - bundle_plan = pimms.plan(bundle_plan) - inputs = {} inputs["b_sls"] = b_sls inputs["preproc_imap"] = preproc_imap @@ -360,11 +330,11 @@ def run_bundle_rec_plan( for key, value in segmentation_params.items(): inputs[key] = value - bundle_imap = bundle_plan(inputs) for criterion in bundle_criterion_order: if b_sls and criterion in bundle_def: - bundle_imap[criterion] - bundle_imap["mahalanobis"] + inputs[criterion] = globals()[criterion](**inputs) + if b_sls: + mahalanobis(**inputs) if b_sls and not b_sls.oriented_yet: raise ValueError( From 27f3183ee32305fd1fb33d7a6e5203c2d4e5005f Mon Sep 17 00:00:00 2001 From: 36000 Date: Thu, 16 May 2024 10:09:38 -0700 Subject: [PATCH 09/11] add ROI testing --- AFQ/tests/test_roi.py | 55 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 AFQ/tests/test_roi.py diff --git a/AFQ/tests/test_roi.py b/AFQ/tests/test_roi.py new file mode 100644 index 000000000..2ccb18e33 --- /dev/null +++ b/AFQ/tests/test_roi.py @@ -0,0 +1,55 @@ +import numpy as np +from scipy.spatial.distance import cdist + +from AFQ.bundle_rec.roi import ( + check_sls_with_inclusion, + check_sl_with_inclusion, + check_sl_with_exclusion) + + +streamline1 = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) +streamline2 = np.array([[1, 1, 1], [2, 2, 2], [3, 3, 3]]) +roi1 = np.array([[1, 2, 3], [4, 5, 6]]) +roi2 = np.array([[7, 8, 9], [10, 11, 12]]) +include_rois = [roi1, roi2] +exclude_rois = [roi1] +include_roi_tols = [10, 10] +exclude_roi_tols = [1] + + +def test_check_sls_with_inclusion(): + sls = [streamline1, streamline2] + result = list(check_sls_with_inclusion( + sls, include_rois, include_roi_tols)) + assert result[0][0] is True + assert np.allclose( + result[0][1][0], cdist(streamline1, roi1, 
'sqeuclidean')) + assert np.allclose( + result[0][1][1], cdist(streamline1, roi2, 'sqeuclidean')) + assert result[1][0] is False + + +def test_check_sl_with_inclusion_pass(): + result, dists = check_sl_with_inclusion( + streamline1, include_rois, include_roi_tols) + assert result is True + assert len(dists) == 2 + + +def test_check_sl_with_inclusion_fail(): + result, dists = check_sl_with_inclusion( + streamline2, include_rois, include_roi_tols) + assert result is False + assert dists == [] + + +def test_check_sl_with_exclusion_pass(): + result = check_sl_with_exclusion( + streamline1, exclude_rois, exclude_roi_tols) + assert result is False + + +def test_check_sl_with_exclusion_fail(): + result = check_sl_with_exclusion( + streamline2, exclude_rois, exclude_roi_tols) + assert result is True From 3b9f05aefc16c097d25ce649db54fef7b4f0ff8c Mon Sep 17 00:00:00 2001 From: 36000 Date: Tue, 28 May 2024 09:45:19 -0700 Subject: [PATCH 10/11] rename bundle_rec to recognition; add recognition specific tests folder; other minor changes --- AFQ/_fixes.py | 17 ++- AFQ/api/participant.py | 2 +- AFQ/{bundle_rec => recognition}/__init__.py | 0 AFQ/{bundle_rec => recognition}/cleaning.py | 15 +-- AFQ/{bundle_rec => recognition}/criteria.py | 0 AFQ/{bundle_rec => recognition}/curvature.py | 0 AFQ/{bundle_rec => recognition}/preprocess.py | 0 AFQ/{bundle_rec => recognition}/recognize.py | 7 +- AFQ/{bundle_rec => recognition}/roi.py | 0 AFQ/recognition/tests/__init__.py | 0 .../tests/test_recognition.py} | 103 +----------------- .../tests/test_rois.py} | 49 ++++++++- AFQ/recognition/tests/test_utils.py | 75 +++++++++++++ AFQ/{bundle_rec => recognition}/utils.py | 0 AFQ/tasks/segmentation.py | 2 +- AFQ/tests/test_fixes.py | 4 +- AFQ/utils/bin.py | 8 +- AFQ/viz/plot.py | 12 +- 18 files changed, 160 insertions(+), 134 deletions(-) rename AFQ/{bundle_rec => recognition}/__init__.py (100%) rename AFQ/{bundle_rec => recognition}/cleaning.py (93%) rename AFQ/{bundle_rec => recognition}/criteria.py (100%) rename AFQ/{bundle_rec => recognition}/curvature.py (100%) rename AFQ/{bundle_rec => recognition}/preprocess.py (100%) rename AFQ/{bundle_rec => recognition}/recognize.py (98%) rename AFQ/{bundle_rec => recognition}/roi.py (100%) create mode 100644 AFQ/recognition/tests/__init__.py rename AFQ/{tests/test_segmentation.py => recognition/tests/test_recognition.py} (69%) rename AFQ/{tests/test_roi.py => recognition/tests/test_rois.py} (51%) create mode 100644 AFQ/recognition/tests/test_utils.py rename AFQ/{bundle_rec => recognition}/utils.py (100%) diff --git a/AFQ/_fixes.py b/AFQ/_fixes.py index cf8c03e88..b7d1cba1a 100644 --- a/AFQ/_fixes.py +++ b/AFQ/_fixes.py @@ -1,4 +1,5 @@ import numpy as np +import logging from scipy.special import lpmv, gammaln @@ -12,6 +13,9 @@ from scipy.linalg import blas +logger = logging.getLogger('AFQ') + + def gwi_odf(gqmodel, data): gqi_vector = np.real( squared_radial_component(np.dot( @@ -153,7 +157,7 @@ def tensor_odf(evals, evecs, sphere, num_batches=100): def gaussian_weights(bundle, n_points=100, return_mahalnobis=False, - stat=np.mean, resample=True): + stat=np.mean): """ Calculate weights for each streamline/node in a bundle, based on a Mahalanobis distance from the core the bundle, at that node (mean, per @@ -163,9 +167,9 @@ def gaussian_weights(bundle, n_points=100, return_mahalnobis=False, ---------- bundle : Streamlines The streamlines to weight. - n_points : int, optional - The number of points to resample to. *If the `bundle` is an array, - this input is ignored*. 
Default: 100.
+    n_points : int or None, optional
+        The number of points to resample to. If this is None, we assume bundle
+        is already resampled, and do not do any resampling. Default: 100.
     return_mahalanobis : bool, optional
         Whether to return the Mahalanobis distance instead of the weights.
         Default: False.
@@ -185,7 +189,7 @@ def gaussian_weights(bundle, n_points=100, return_mahalnobis=False,
         coordinates at that node position across streamlines.
 
     """
-    if resample:
+    if n_points is not None:
         if isinstance(bundle, np.ndarray):
             bundle = bundle.tolist()
         if isinstance(bundle, list):
@@ -198,6 +202,9 @@ def gaussian_weights(bundle, n_points=100, return_mahalnobis=False,
 
     if n_sls < 15:  # Cov^-1 unstable under this amount
         weights = np.ones((n_sls, n_nodes))
+        logger.warning((
+            "Not enough streamlines for weight calculation, "
+            "weighting everything evenly"))
         if return_mahalnobis:
             return np.full((n_sls, n_nodes), np.nan)
         else:
diff --git a/AFQ/api/participant.py b/AFQ/api/participant.py
index 9b87bda83..3fb315cca 100644
--- a/AFQ/api/participant.py
+++ b/AFQ/api/participant.py
@@ -365,7 +365,7 @@ def cmd_outputs(self, cmd="rm", dependent_on=None, exceptions=[],
         if isinstance(file_name, str):
             exception_file_names.append(file_name)
         else:
-            self.logger.warn((
+            self.logger.warning((
                 f"The exception '{exception}' does not correspond"
                 " to a filename and will be ignored."))
 
diff --git a/AFQ/bundle_rec/__init__.py b/AFQ/recognition/__init__.py
similarity index 100%
rename from AFQ/bundle_rec/__init__.py
rename to AFQ/recognition/__init__.py
diff --git a/AFQ/bundle_rec/cleaning.py b/AFQ/recognition/cleaning.py
similarity index 93%
rename from AFQ/bundle_rec/cleaning.py
rename to AFQ/recognition/cleaning.py
index 498134293..2415a27b2 100644
--- a/AFQ/bundle_rec/cleaning.py
+++ b/AFQ/recognition/cleaning.py
@@ -6,7 +6,7 @@
 import dipy.tracking.streamline as dts
 from dipy.io.stateful_tractogram import StatefulTractogram, Space
 
-import AFQ.bundle_rec.utils as abu
+import AFQ.recognition.utils as abu
 from AFQ._fixes import gaussian_weights
 
 
@@ -25,12 +25,7 @@ def clean_by_orientation(streamlines, primary_axis, tol=None):
 
     Returns
     -------
-    cleaned_idx, indicies of streamlines that passed cleaning,
-    logical_and of other two returns
-    along_accepted_idx, indices of streamlines that passed
-    cleaning along the bundle
-    end_accepted_idx, indices of streamlines that passed
-    cleaning based on difference between endpoints of bundle
+    cleaned_idx, indices of streamlines that passed cleaning
     """
     axis_diff = np.zeros((len(streamlines), 3))
     endpoint_diff = np.zeros((len(streamlines), 3))
@@ -111,6 +106,9 @@ def clean_bundle(tg, n_points=100, clean_rounds=5, distance_threshold=3,
 
     # We don't even bother if there aren't enough streamlines:
     if len(streamlines) < min_sl:
+        logger.warning((
+            "Mahalanobis cleaning halted early"
+            " due to low streamline count"))
         if return_idx:
             return tg, np.arange(len(streamlines))
         else:
@@ -136,8 +134,7 @@ def clean_bundle(tg, n_points=100, clean_rounds=5, distance_threshold=3,
     # This calculates the Mahalanobis for each streamline/node:
     m_dist = gaussian_weights(
         fgarray, return_mahalnobis=True,
-        n_points=n_points, stat=stat,
-        resample=False)
+        n_points=None, stat=stat)
     logger.debug(f"Shape of fgarray: {np.asarray(fgarray).shape}")
     logger.debug(f"Shape of m_dist: {m_dist.shape}")
     logger.debug(f"Maximum m_dist: {np.max(m_dist)}")
diff --git a/AFQ/bundle_rec/criteria.py b/AFQ/recognition/criteria.py
similarity index 100%
rename from AFQ/bundle_rec/criteria.py
rename to
AFQ/recognition/criteria.py diff --git a/AFQ/bundle_rec/curvature.py b/AFQ/recognition/curvature.py similarity index 100% rename from AFQ/bundle_rec/curvature.py rename to AFQ/recognition/curvature.py diff --git a/AFQ/bundle_rec/preprocess.py b/AFQ/recognition/preprocess.py similarity index 100% rename from AFQ/bundle_rec/preprocess.py rename to AFQ/recognition/preprocess.py diff --git a/AFQ/bundle_rec/recognize.py b/AFQ/recognition/recognize.py similarity index 98% rename from AFQ/bundle_rec/recognize.py rename to AFQ/recognition/recognize.py index 890bf305f..0930776f2 100644 --- a/AFQ/bundle_rec/recognize.py +++ b/AFQ/recognition/recognize.py @@ -211,10 +211,9 @@ def recognize( bundle_decisions, np.ones((n_streamlines, 1))), axis=1) bundle_decisions = np.argmax(bundle_decisions, -1) - # We do another round through, so that we can orient all the - # streamlines within a bundle in the same orientation with respect to - # the ROIs. This order is ARBITRARY but CONSISTENT (going from ROI0 - # to ROI1). + # We do another round through, so that we can: + # 1. Clip streamlines according to ROIs + # 2. Re-orient streamlines logger.info("Re-orienting streamlines to consistent directions") for bundle_idx, bundle in enumerate(bundle_dict.bundle_names): logger.info(f"Processing {bundle}") diff --git a/AFQ/bundle_rec/roi.py b/AFQ/recognition/roi.py similarity index 100% rename from AFQ/bundle_rec/roi.py rename to AFQ/recognition/roi.py diff --git a/AFQ/recognition/tests/__init__.py b/AFQ/recognition/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/AFQ/tests/test_segmentation.py b/AFQ/recognition/tests/test_recognition.py similarity index 69% rename from AFQ/tests/test_segmentation.py rename to AFQ/recognition/tests/test_recognition.py index 41f02b56c..eea3e25f7 100644 --- a/AFQ/tests/test_segmentation.py +++ b/AFQ/recognition/tests/test_recognition.py @@ -13,11 +13,8 @@ import AFQ.data.fetch as afd import AFQ.registration as reg -import AFQ.bundle_rec.cleaning as abc -import AFQ.bundle_rec.curvature as abv -import AFQ.bundle_rec.utils as abu -import AFQ.bundle_rec.roi as abr -from AFQ.bundle_rec.recognize import recognize +import AFQ.recognition.cleaning as abc +from AFQ.recognition.recognize import recognize dpd.fetch_stanford_hardi() @@ -132,43 +129,6 @@ def test_segment_return_idx(): npt.assert_(len(fiber_groups['Right Corticospinal']['idx']) > 0) -def test_segment_sl_curve(): - sl_disp_0 = abv.sl_curve(streamlines[4], 4) - npt.assert_array_almost_equal( - sl_disp_0, - [[-0.236384, -0.763855, 0.60054 ], - [ 0.232594, -0.867859, -0.439 ], - [ 0.175343, 0.001082, -0.984507]]) - - sl_disp_1 = abv.sl_curve(streamlines[2], 4) - mean_angle_diff = abv.sl_curve_dist(sl_disp_0, sl_disp_1) - npt.assert_almost_equal(mean_angle_diff, 1.701458, decimal=3) - - -def test_segment_clip_edges(): - sls = tg.streamlines - idx = np.arange(len(tg.streamlines)) - accepted_sls = sls[[4, 10, 11]] - accepted_ix = idx[[4, 10, 11]] - bundle_roi_dists = np.zeros((len(sls), 3)) - bundle_roi_dists[4, :] = [5, 10, 15] - bundle_roi_dists[10, :] = [3, 6, 9] - bundle_roi_dists[11, :] = [10, 10, 10] - cut_sls = abu.cut_sls_by_dist( - accepted_sls, - bundle_roi_dists[accepted_ix], - [0, 2]) - npt.assert_array_equal( - cut_sls[0], - accepted_sls[0][5:15]) - npt.assert_array_equal( - cut_sls[1], - accepted_sls[1][3:9]) - npt.assert_array_equal( - cut_sls[2], - accepted_sls[2][9:11]) - - @pytest.mark.nightly def test_segment_clip_edges_api(): # Test with the clip_edges kwarg set to True: @@ -205,50 +165,6 @@ 
def test_segment_reco(): npt.assert_(len(fiber_groups['CST_R']) > 0) -def test_clean_by_endpoints(): - sl = [np.array([[1, 1, 1], - [2, 1, 1], - [3, 1, 1], - [4, 1, 1]]), - np.array([[1, 1, 2], - [2, 1, 2], - [3, 1, 2], - [4, 1, 2]]), - np.array([[1, 1, 1], - [2, 1, 1], - [3, 1, 1]]), - np.array([[1, 1, 1], - [2, 1, 1]])] - - atlas = np.zeros((20, 20, 20)) - - # Targets: - atlas[1, 1, 1] = 1 - atlas[1, 1, 2] = 2 - atlas[4, 1, 1] = 3 - atlas[4, 1, 2] = 4 - - target_img_start = nib.Nifti1Image( - np.logical_or(atlas==1, atlas==2).astype(np.float32), np.eye(4)) - target_img_end = nib.Nifti1Image( - np.logical_or(atlas==3, atlas==4).astype(np.float32), np.eye(4)) - - clean_idx_start = list(abr.clean_by_endpoints( - sl, target_img_start, 0)) - clean_idx_end = list(abr.clean_by_endpoints( - sl, target_img_end, -1)) - npt.assert_array_equal(np.logical_and( - clean_idx_start, clean_idx_end), np.array([1, 1, 0, 0])) - - # If tol=1, the third streamline also gets included - clean_idx_start = list(abr.clean_by_endpoints( - sl, target_img_start, 0, tol=1)) - clean_idx_end = list(abr.clean_by_endpoints( - sl, target_img_end, -1, tol=1)) - npt.assert_array_equal(np.logical_and( - clean_idx_start, clean_idx_end), np.array([1, 1, 1, 0])) - - def test_exclusion_ROI(): slf_bundle = { 'Left Superior Longitudinal': { @@ -294,21 +210,6 @@ def test_exclusion_ROI(): npt.assert_equal(len(fiber_groups["Left Superior Longitudinal"]), 1) -def test_segment_orientation(): - cleaned_idx = \ - abc.clean_by_orientation(streamlines, primary_axis=1) - npt.assert_equal(np.sum(cleaned_idx), 93) - cleaned_idx_tol = \ - abc.clean_by_orientation(streamlines, primary_axis=1, tol=50) - npt.assert_(np.sum(cleaned_idx_tol) < np.sum(cleaned_idx)) - - cleaned_idx = \ - abc.clean_by_orientation(streamlines, primary_axis=2) - cleaned_idx_tol = \ - abc.clean_by_orientation(streamlines, primary_axis=2, tol=33) - npt.assert_array_equal(cleaned_idx_tol, cleaned_idx) - - def test_segment_sampled_streamlines(): fiber_groups, _ = recognize( tg, diff --git a/AFQ/tests/test_roi.py b/AFQ/recognition/tests/test_rois.py similarity index 51% rename from AFQ/tests/test_roi.py rename to AFQ/recognition/tests/test_rois.py index 2ccb18e33..3b0824613 100644 --- a/AFQ/tests/test_roi.py +++ b/AFQ/recognition/tests/test_rois.py @@ -1,6 +1,9 @@ import numpy as np +import AFQ.recognition.roi as abr +import nibabel as nib +import numpy.testing as npt +import numpy as np from scipy.spatial.distance import cdist - from AFQ.bundle_rec.roi import ( check_sls_with_inclusion, check_sl_with_inclusion, @@ -17,6 +20,50 @@ exclude_roi_tols = [1] +def test_clean_by_endpoints(): + sl = [np.array([[1, 1, 1], + [2, 1, 1], + [3, 1, 1], + [4, 1, 1]]), + np.array([[1, 1, 2], + [2, 1, 2], + [3, 1, 2], + [4, 1, 2]]), + np.array([[1, 1, 1], + [2, 1, 1], + [3, 1, 1]]), + np.array([[1, 1, 1], + [2, 1, 1]])] + + atlas = np.zeros((20, 20, 20)) + + # Targets: + atlas[1, 1, 1] = 1 + atlas[1, 1, 2] = 2 + atlas[4, 1, 1] = 3 + atlas[4, 1, 2] = 4 + + target_img_start = nib.Nifti1Image( + np.logical_or(atlas==1, atlas==2).astype(np.float32), np.eye(4)) + target_img_end = nib.Nifti1Image( + np.logical_or(atlas==3, atlas==4).astype(np.float32), np.eye(4)) + + clean_idx_start = list(abr.clean_by_endpoints( + sl, target_img_start, 0)) + clean_idx_end = list(abr.clean_by_endpoints( + sl, target_img_end, -1)) + npt.assert_array_equal(np.logical_and( + clean_idx_start, clean_idx_end), np.array([1, 1, 0, 0])) + + # If tol=1, the third streamline also gets included + clean_idx_start = 
list(abr.clean_by_endpoints( + sl, target_img_start, 0, tol=1)) + clean_idx_end = list(abr.clean_by_endpoints( + sl, target_img_end, -1, tol=1)) + npt.assert_array_equal(np.logical_and( + clean_idx_start, clean_idx_end), np.array([1, 1, 1, 0])) + + def test_check_sls_with_inclusion(): sls = [streamline1, streamline2] result = list(check_sls_with_inclusion( diff --git a/AFQ/recognition/tests/test_utils.py b/AFQ/recognition/tests/test_utils.py new file mode 100644 index 000000000..29d2d2472 --- /dev/null +++ b/AFQ/recognition/tests/test_utils.py @@ -0,0 +1,75 @@ +import numpy.testing as npt +import numpy as np +import nibabel as nib +import os.path as op + +import AFQ.data.fetch as afd +import AFQ.recognition.curvature as abv +import AFQ.recognition.utils as abu +import AFQ.recognition.cleaning as abc + + +from dipy.io.stateful_tractogram import StatefulTractogram, Space +import dipy.data.fetcher as fetcher + + +hardi_dir = op.join(fetcher.dipy_home, "stanford_hardi") +hardi_fdata = op.join(hardi_dir, "HARDI150.nii.gz") +hardi_img = nib.load(hardi_fdata) +file_dict = afd.read_stanford_hardi_tractography() +streamlines = file_dict['tractography_subsampled.trk'] +tg = StatefulTractogram(streamlines, hardi_img, Space.RASMM) +tg.to_vox() +streamlines = tg.streamlines + + +def test_segment_sl_curve(): + sl_disp_0 = abv.sl_curve(streamlines[4], 4) + npt.assert_array_almost_equal( + sl_disp_0, + [[-0.236384, -0.763855, 0.60054 ], + [ 0.232594, -0.867859, -0.439 ], + [ 0.175343, 0.001082, -0.984507]]) + + sl_disp_1 = abv.sl_curve(streamlines[2], 4) + mean_angle_diff = abv.sl_curve_dist(sl_disp_0, sl_disp_1) + npt.assert_almost_equal(mean_angle_diff, 1.701458, decimal=3) + + +def test_segment_clip_edges(): + sls = tg.streamlines + idx = np.arange(len(tg.streamlines)) + accepted_sls = sls[[4, 10, 11]] + accepted_ix = idx[[4, 10, 11]] + bundle_roi_dists = np.zeros((len(sls), 3)) + bundle_roi_dists[4, :] = [5, 10, 15] + bundle_roi_dists[10, :] = [3, 6, 9] + bundle_roi_dists[11, :] = [10, 10, 10] + cut_sls = abu.cut_sls_by_dist( + accepted_sls, + bundle_roi_dists[accepted_ix], + [0, 2]) + npt.assert_array_equal( + cut_sls[0], + accepted_sls[0][5:15]) + npt.assert_array_equal( + cut_sls[1], + accepted_sls[1][3:9]) + npt.assert_array_equal( + cut_sls[2], + accepted_sls[2][9:11]) + + +def test_segment_orientation(): + cleaned_idx = \ + abc.clean_by_orientation(streamlines, primary_axis=1) + npt.assert_equal(np.sum(cleaned_idx), 93) + cleaned_idx_tol = \ + abc.clean_by_orientation(streamlines, primary_axis=1, tol=50) + npt.assert_(np.sum(cleaned_idx_tol) < np.sum(cleaned_idx)) + + cleaned_idx = \ + abc.clean_by_orientation(streamlines, primary_axis=2) + cleaned_idx_tol = \ + abc.clean_by_orientation(streamlines, primary_axis=2, tol=33) + npt.assert_array_equal(cleaned_idx_tol, cleaned_idx) diff --git a/AFQ/bundle_rec/utils.py b/AFQ/recognition/utils.py similarity index 100% rename from AFQ/bundle_rec/utils.py rename to AFQ/recognition/utils.py diff --git a/AFQ/tasks/segmentation.py b/AFQ/tasks/segmentation.py index 6b979dd57..65be67eeb 100644 --- a/AFQ/tasks/segmentation.py +++ b/AFQ/tasks/segmentation.py @@ -10,7 +10,7 @@ from AFQ.tasks.decorators import as_file from AFQ.tasks.utils import get_fname, with_name, str_to_desc -from AFQ.bundle_rec.recognize import recognize +from AFQ.recognition.recognize import recognize from AFQ.utils.path import drop_extension, write_json import AFQ.utils.streamlines as aus from AFQ.tasks.utils import get_default_args diff --git a/AFQ/tests/test_fixes.py 
b/AFQ/tests/test_fixes.py index 7411f1e1f..2007df9d6 100644 --- a/AFQ/tests/test_fixes.py +++ b/AFQ/tests/test_fixes.py @@ -67,6 +67,6 @@ def test_mahal_fix(): [0., 0., 0., 3.415999, 2.689814]]) npt.assert_array_almost_equal( gaussian_weights_fast( - sls_array, n_points=5, + sls_array, n_points=None, return_mahalnobis=True, - stat=np.mean, resample=False)[:3], results) + stat=np.mean)[:3], results) diff --git a/AFQ/utils/bin.py b/AFQ/utils/bin.py index cbfcc3ad4..a73bf70ef 100644 --- a/AFQ/utils/bin.py +++ b/AFQ/utils/bin.py @@ -165,8 +165,8 @@ def dict_to_json(dictionary): def func_dict_to_arg_dict(func_dict=None, logger=None): if func_dict is None: - from AFQ.bundle_rec.recognize import recognize - from AFQ.bundle_rec.cleaning import clean_bundle + from AFQ.recognition.recognize import recognize + from AFQ.recognition.cleaning import clean_bundle import AFQ.tractography.tractography as aft from AFQ.api.group import GroupAFQ @@ -412,8 +412,8 @@ def generate_json(json_folder, overwrite=False, } ] }""" - from AFQ.bundle_rec.recognize import recognize - from AFQ.bundle_rec.cleaning import clean_bundle + from AFQ.recognition.recognize import recognize + from AFQ.recognition.cleaning import clean_bundle import AFQ.tractography.tractography as aft func_dict = { diff --git a/AFQ/viz/plot.py b/AFQ/viz/plot.py index b92ae6fc7..bb8ca8158 100644 --- a/AFQ/viz/plot.py +++ b/AFQ/viz/plot.py @@ -487,7 +487,7 @@ def _get_profile(self, name, bundle, subject, scalar): + ' for bundle ' + bundle + ' for CSV ' + name) if np.sum(nans) > self.percent_nan_tol: - self.logger.warn(message + '. Profile ignored. ') + self.logger.warning(message + '. Profile ignored. ') return None else: self.logger.info(message + '. NaNs interpolated. ') @@ -545,7 +545,7 @@ def masked_corr(self, arr, corrtype): raters="raters", ratings="ratings") row = stats[stats["Type"] == self.ICC_func].iloc[0] - return row["ICC"], row["ICC"] - row["CI95%"][0],\ + return row["ICC"], row["ICC"] - row["CI95%"][0], \ row["CI95%"][1] - row["ICC"] elif corrtype == "Srho": stats = corr( @@ -553,7 +553,7 @@ def masked_corr(self, arr, corrtype): y=arr[1], method="spearman") row = stats.iloc[0] - return row["r"], row["r"] - row["CI95%"][0],\ + return row["r"], row["r"] - row["CI95%"][0], \ row["CI95%"][1] - row["r"] else: raise ValueError("corrtype not recognized") @@ -978,7 +978,7 @@ def reliability_plots(self, names=None, bundle_profiles[j, i] = single_profile all_sub_means[m, k] = np.nanmean(bundle_profiles, axis=2) - all_sub_coef[m, k], all_sub_coef_err[m, k, 0],\ + all_sub_coef[m, k], all_sub_coef_err[m, k, 0], \ all_sub_coef_err[m, k, 1] =\ self.masked_corr(all_sub_means[m, k], "Srho") if np.isnan(all_sub_coef[m, k]).all(): @@ -1359,8 +1359,8 @@ def reliability_plots(self, names=None, if not show_plots: plt.close(fig) plt.ion() - return fig, axes, miss_counts, updated_bundles,\ - all_sub_coef_removed, all_sub_coef_err_removed,\ + return fig, axes, miss_counts, updated_bundles, \ + all_sub_coef_removed, all_sub_coef_err_removed, \ bundle_prof_means_removed, bundle_prof_stds_removed def compare_reliability(self, reliability1, reliability2, From 09b57b4c8626bf9380001587fdccfeba2a2987b1 Mon Sep 17 00:00:00 2001 From: 36000 Date: Tue, 28 May 2024 09:49:22 -0700 Subject: [PATCH 11/11] update imports and formatting changes --- AFQ/recognition/criteria.py | 8 ++++---- AFQ/recognition/preprocess.py | 2 +- AFQ/recognition/recognize.py | 6 +++--- AFQ/recognition/tests/test_rois.py | 6 +++--- AFQ/recognition/tests/test_utils.py | 6 +++--- 5 files changed, 14 
insertions(+), 14 deletions(-) diff --git a/AFQ/recognition/criteria.py b/AFQ/recognition/criteria.py index d078e246a..f53cf7b29 100644 --- a/AFQ/recognition/criteria.py +++ b/AFQ/recognition/criteria.py @@ -11,10 +11,10 @@ from dipy.segment.bundles import RecoBundles from dipy.io.stateful_tractogram import StatefulTractogram, Space -import AFQ.bundle_rec.utils as abu -import AFQ.bundle_rec.cleaning as abc -import AFQ.bundle_rec.curvature as abv -import AFQ.bundle_rec.roi as abr +import AFQ.recognition.utils as abu +import AFQ.recognition.cleaning as abc +import AFQ.recognition.curvature as abv +import AFQ.recognition.roi as abr bundle_criterion_order = [ "prob_map", "cross_midline", "start", "end", diff --git a/AFQ/recognition/preprocess.py b/AFQ/recognition/preprocess.py index 9beac5b55..f9a2d9978 100644 --- a/AFQ/recognition/preprocess.py +++ b/AFQ/recognition/preprocess.py @@ -5,7 +5,7 @@ import dipy.tracking.streamline as dts -import AFQ.bundle_rec.utils as abu +import AFQ.recognition.utils as abu logger = logging.getLogger('AFQ') diff --git a/AFQ/recognition/recognize.py b/AFQ/recognition/recognize.py index 0930776f2..2671ca7c9 100644 --- a/AFQ/recognition/recognize.py +++ b/AFQ/recognition/recognize.py @@ -6,10 +6,10 @@ import dipy.tracking.streamlinespeed as dps from dipy.io.stateful_tractogram import StatefulTractogram, Space -import AFQ.bundle_rec.utils as abu +import AFQ.recognition.utils as abu from AFQ.api.bundle_dict import BundleDict -from AFQ.bundle_rec.criteria import run_bundle_rec_plan -from AFQ.bundle_rec.preprocess import get_preproc_plan +from AFQ.recognition.criteria import run_bundle_rec_plan +from AFQ.recognition.preprocess import get_preproc_plan logger = logging.getLogger('AFQ') diff --git a/AFQ/recognition/tests/test_rois.py b/AFQ/recognition/tests/test_rois.py index 3b0824613..97304fa8a 100644 --- a/AFQ/recognition/tests/test_rois.py +++ b/AFQ/recognition/tests/test_rois.py @@ -4,7 +4,7 @@ import numpy.testing as npt import numpy as np from scipy.spatial.distance import cdist -from AFQ.bundle_rec.roi import ( +from AFQ.recognition.roi import ( check_sls_with_inclusion, check_sl_with_inclusion, check_sl_with_exclusion) @@ -44,9 +44,9 @@ def test_clean_by_endpoints(): atlas[4, 1, 2] = 4 target_img_start = nib.Nifti1Image( - np.logical_or(atlas==1, atlas==2).astype(np.float32), np.eye(4)) + np.logical_or(atlas == 1, atlas == 2).astype(np.float32), np.eye(4)) target_img_end = nib.Nifti1Image( - np.logical_or(atlas==3, atlas==4).astype(np.float32), np.eye(4)) + np.logical_or(atlas == 3, atlas == 4).astype(np.float32), np.eye(4)) clean_idx_start = list(abr.clean_by_endpoints( sl, target_img_start, 0)) diff --git a/AFQ/recognition/tests/test_utils.py b/AFQ/recognition/tests/test_utils.py index 29d2d2472..d6033d44e 100644 --- a/AFQ/recognition/tests/test_utils.py +++ b/AFQ/recognition/tests/test_utils.py @@ -27,9 +27,9 @@ def test_segment_sl_curve(): sl_disp_0 = abv.sl_curve(streamlines[4], 4) npt.assert_array_almost_equal( sl_disp_0, - [[-0.236384, -0.763855, 0.60054 ], - [ 0.232594, -0.867859, -0.439 ], - [ 0.175343, 0.001082, -0.984507]]) + [[-0.236384, -0.763855, 0.60054], + [0.232594, -0.867859, -0.439], + [0.175343, 0.001082, -0.984507]]) sl_disp_1 = abv.sl_curve(streamlines[2], 4) mean_angle_diff = abv.sl_curve_dist(sl_disp_0, sl_disp_1)
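To close out the series, the low-streamline guard that patch 04 added to `gaussian_weights` (and that patch 10 re-keyed to `n_points=None` for pre-resampled input) can be exercised directly. This is a minimal sanity check, assuming pyAFQ is installed with these patches applied; the array sizes here are arbitrary:

import numpy as np

from AFQ._fixes import gaussian_weights

rng = np.random.default_rng(0)
# 10 fake streamlines, already resampled to 5 nodes of 3D coordinates;
# n_points=None tells the patched function to skip resampling.
sls = rng.normal(size=(10, 5, 3))

# With fewer than 15 streamlines, Cov^-1 is unstable, so the weights
# fall back to uniform: ones normalized across streamlines at each node.
w = gaussian_weights(sls, n_points=None)
assert w.shape == (10, 5)
assert np.allclose(w, 0.1)

# The Mahalanobis variant signals the same fallback with NaNs.
m = gaussian_weights(sls, n_points=None, return_mahalnobis=True)
assert np.isnan(m).all()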