Skip to content

Commit

Permalink
Merge pull request #1132 from 36000/massive_seg_refac
Browse files Browse the repository at this point in the history
[REF] Refactor segmentation code into many files and pimms system
  • Loading branch information
36000 authored May 31, 2024
2 parents 7473e6e + 09b57b4 commit 06865e6
Show file tree
Hide file tree
Showing 31 changed files with 1,786 additions and 1,890 deletions.
55 changes: 30 additions & 25 deletions AFQ/_fixes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import logging

from scipy.special import lpmv, gammaln

Expand All @@ -12,6 +13,9 @@
from scipy.linalg import blas


logger = logging.getLogger('AFQ')


def gwi_odf(gqmodel, data):
gqi_vector = np.real(
squared_radial_component(np.dot(
Expand Down Expand Up @@ -163,17 +167,20 @@ def gaussian_weights(bundle, n_points=100, return_mahalnobis=False,
----------
bundle : Streamlines
The streamlines to weight.
n_points : int, optional
The number of points to resample to. *If the `bundle` is an array,
this input is ignored*. Default: 100.
n_points : int or None, optional
The number of points to resample to. If this is None, we assume bundle
is already resampled, and do not do any resampling. Default: 100.
return_mahalanobis : bool, optional
Whether to return the Mahalanobis distance instead of the weights.
Default: False.
stat : callable, optional.
The statistic used to calculate the central tendency of streamlines in
each node. Can be one of {`np.mean`, `np.median`} or other functions
that have similar API. Default: `np.mean`
resample : bool, optional
Whether its necessary to resample the streamlines to the same number
of points. Only set to False if they are already resampled.
Default: True.
Returns
-------
w : array of shape (n_streamlines, n_points)
Expand All @@ -182,38 +189,36 @@ def gaussian_weights(bundle, n_points=100, return_mahalnobis=False,
coordinates at that node position across streamlines.
"""

# Resample to same length for each streamline
# if necessary
resample = False
if isinstance(bundle, np.ndarray):
if len(bundle.shape) > 2:
if bundle.shape[1] != n_points:
sls = bundle.tolist()
sls = [np.asarray(item) for item in sls]
resample = True
else:
resample = True
if resample:
sls = set_number_of_points(bundle, n_points)
if n_points is not None:
if isinstance(bundle, np.ndarray):
bundle = bundle.tolist()
if isinstance(bundle, list):
bundle = [np.asarray(item) for item in bundle]
sls = np.asarray(set_number_of_points(bundle, n_points))
else:
sls = bundle

# If there's only one fiber here, it gets the entire weighting:
if len(bundle) == 1:
n_sls, n_nodes, _ = sls.shape

if n_sls < 15: # Cov^-1 unstable under this amount
weights = np.ones((n_sls, n_nodes))
logger.warning((
"Not enough streamlines for weight calculation, "
"weighting everything evenly"))
if return_mahalnobis:
return np.array([np.nan])
return np.full((n_sls, n_nodes), np.nan)
else:
return np.array([1])

n_sls, n_nodes, n_dim = sls.shape
weights = np.zeros((n_sls, n_nodes))
return weights / np.sum(weights, 0)
else:
weights = np.zeros((n_sls, n_nodes))
diff = stat(sls, axis=0) - sls
for i in range(n_nodes):
# This should come back as a 3D covariance matrix with the spatial
# variance covariance of this node across the different streamlines,
# reorganized as an upper diagonal matrix for expected Mahalanobis
cov = np.cov(sls[:, i, :].T, ddof=0)
while np.any(np.linalg.eigvals(cov) < 0):
cov += np.eye(cov.shape[0]) * 1e-12

# calculate Mahalanobis for node in every fiber
if np.any(cov > 0):
Expand Down
Loading

0 comments on commit 06865e6

Please sign in to comment.