Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP, ENH] Add GODEC #120

Closed
wants to merge 26 commits into from
Closed
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
4eda775
Add godec.
tsalo Aug 26, 2018
0f6eb2f
Fix integration test workflow call.
tsalo Aug 26, 2018
6c4abc3
Fix potential bug in T1c and GODEC.
tsalo Aug 28, 2018
60ce2f3
Remove mean from High-Kappa time series.
tsalo Aug 28, 2018
b7119ea
Fix mmix shape bug.
tsalo Aug 28, 2018
dba5c24
Improve gscontrol_mmix docstring.
tsalo Aug 28, 2018
b1851a0
Incorporate "out_dir" variable into workflow and functions.
tsalo Aug 31, 2018
e72af8f
Fix test (ish)
tsalo Aug 31, 2018
a63a00e
Now *really* fix tests. Well, one test.
tsalo Aug 31, 2018
b27dff3
Fix t2smap and associated tests. Update CLI.
tsalo Aug 31, 2018
21b5d98
Refactor tedana workflow. Make more readable.
tsalo Aug 31, 2018
60da7f3
Fix import issue.
tsalo Sep 1, 2018
6910f5c
Merge remote-tracking branch 'ME-ICA/master' into godec
tsalo Sep 3, 2018
c6cfc88
Revert changes to gscontrol_mmix.
tsalo Sep 9, 2018
9fe4123
And... fix output directory after reverting changes..
tsalo Sep 9, 2018
a540da1
Merge branch 'master' into godec
tsalo Sep 18, 2018
cf44965
Fix style issue.
tsalo Sep 18, 2018
5bdfe9b
Merge remote-tracking branch 'ME-ICA/master' into godec
tsalo Nov 28, 2018
68d59ff
Merge remote-tracking branch 'ME-ICA/master' into godec
tsalo Nov 28, 2018
b69ff09
Undo changes to gitignore.
tsalo Nov 28, 2018
d838a4a
Fix.
tsalo Nov 28, 2018
92c2c8c
Revert out_dir changes.
tsalo Nov 28, 2018
08d4fb9
Revert more.
tsalo Nov 28, 2018
bb04f4c
Fix style problem.
tsalo Nov 28, 2018
9a0ff51
Merge remote-tracking branch 'ME-ICA/master' into godec
tsalo Feb 15, 2019
8b03fbb
Fix style issues.
tsalo Feb 15, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion tedana/decomposition/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@
from .eigendecomp import (
tedpca, tedica,
)
from .godecomp import (
tedgodec
)


__all__ = [
'tedpca', 'tedica',
'tedpca', 'tedica', 'tedgodec',
]
17 changes: 11 additions & 6 deletions tedana/decomposition/eigendecomp.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@


def tedpca(catd, OCcatd, combmode, mask, t2s, t2sG, stabilize,
ref_img, tes, kdaw, rdaw, ste=0, mlepca=True, wvpca=False):
ref_img, tes, kdaw, rdaw, ste=0, mlepca=True, wvpca=False,
out_dir='.'):
"""
Use principal components analysis (PCA) to identify and remove thermal
noise from multi-echo data.
Expand Down Expand Up @@ -57,6 +58,8 @@ def tedpca(catd, OCcatd, combmode, mask, t2s, t2sG, stabilize,
guessing PCA dimensionality instead of a traditional SVD. Default: True
wvpca : :obj:`bool`, optional
Whether to apply wavelet denoising to data. Default: False
out_dir : :obj:`str`, optional
Output directory in which to save output files

Returns
-------
Expand Down Expand Up @@ -137,7 +140,7 @@ def tedpca(catd, OCcatd, combmode, mask, t2s, t2sG, stabilize,
if wvpca:
dz, cAl = dwtmat(dz)

if not op.exists('pcastate.pkl'):
if not op.exists(op.join(out_dir, 'pcastate.pkl')):
# do PC dimension selection and get eigenvalue cutoff
if mlepca:
from sklearn.decomposition import PCA
Expand Down Expand Up @@ -177,7 +180,7 @@ def tedpca(catd, OCcatd, combmode, mask, t2s, t2sG, stabilize,
ctb = np.vstack([ctb.T[:3], sp]).T

# Save state
fname = op.abspath('pcastate.pkl')
fname = op.abspath(op.join(out_dir, 'pcastate.pkl'))
LGR.info('Saving PCA results to: {}'.format(fname))
pcastate = {'u': u, 's': s, 'v': v, 'ctb': ctb,
'eigelb': eigelb, 'spmin': spmin, 'spcum': spcum}
Expand All @@ -189,14 +192,16 @@ def tedpca(catd, OCcatd, combmode, mask, t2s, t2sG, stabilize,

else: # if loading existing state
LGR.info('Loading PCA from: pcastate.pkl')
with open('pcastate.pkl', 'rb') as handle:
with open(op.join(out_dir, 'pcastate.pkl'), 'rb') as handle:
pcastate = pickle.load(handle)
u, s, v = pcastate['u'], pcastate['s'], pcastate['v']
ctb, eigelb = pcastate['ctb'], pcastate['eigelb']
spmin, spcum = pcastate['spmin'], pcastate['spcum']

np.savetxt('comp_table_pca.txt', ctb[ctb[:, 1].argsort(), :][::-1])
np.savetxt('mepca_mix.1D', v[ctb[:, 1].argsort()[::-1], :].T)
np.savetxt(op.join(out_dir, 'comp_table_pca.txt'),
ctb[ctb[:, 1].argsort(), :][::-1])
np.savetxt(op.join(out_dir, 'mepca_mix.1D'),
v[ctb[:, 1].argsort()[::-1], :].T)

kappas = ctb[ctb[:, 1].argsort(), 1]
rhos = ctb[ctb[:, 2].argsort(), 2]
Expand Down
184 changes: 184 additions & 0 deletions tedana/decomposition/godecomp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
"""
Go Decomposition
"""
import logging
import os.path as op

import numpy as np
from numpy.linalg import qr, lstsq

from tedana import utils
from tedana.decomposition._utils import dwtmat, idwtmat

LGR = logging.getLogger(__name__)


def _wthresh(a, thresh):
    """
    Soft wavelet threshold

    Shrinks every element of `a` toward zero by `thresh`; elements whose
    magnitude does not exceed `thresh` become zero.

    Parameters
    ----------
    a : array_like
        Values to threshold.
    thresh : :obj:`float`
        Amount subtracted from the magnitude of each element.

    Returns
    -------
    array_like
        ``sign(a) * max(abs(a) - thresh, 0)``, computed elementwise.
    """
    excess = np.abs(a) - thresh
    # Boolean mask zeroes out entries whose magnitude is within the threshold
    survives = excess > 0
    return np.sign(a) * (survives * excess)


def godec(data, thresh=.03, rank=2, power=1, tol=1e-3, max_iter=100,
          random_seed=0, verbose=True):
    """
    Perform Go Decomposition

    Default threshold of .03 is assumed to be for input in the range 0-1...
    original matlab had 8 out of 255, which is about .03 scaled to 0-1 range

    Parameters
    ----------
    data : (M x T) array_like
        Two-dimensional data to decompose. The input is not modified.
    thresh : :obj:`float`, optional
        Soft threshold applied to obtain the sparse component. Default: .03
    rank : :obj:`int`, optional
        Number of columns in the random projection, i.e. the rank of the
        low-rank estimate. Default: 2
    power : :obj:`int`, optional
        The random projection is refined with ``power + 1`` passes through
        the current low-rank estimate. Default: 1
    tol : :obj:`float`, optional
        Iteration stops once the L2 norm of the residual falls below this
        value. Default: 1e-3
    max_iter : :obj:`int`, optional
        Iteration cap. The check happens after each pass on a zero-based
        counter, so up to ``max_iter + 1`` passes may run. Default: 100
    random_seed : :obj:`int`, optional
        Seed for the random projections. Default: 0
    verbose : :obj:`bool`, optional
        Whether to log the convergence outcome. Default: True

    Returns
    -------
    L : array_like
        Low-rank components. Similar to global signals. Should be discarded,
        according to Power et al. (2018).
    S : array_like
        Sparse components. Should be retained, according to Power et al.
        (2018).
    G : array_like
        Residuals (i.e., data minus sparse and low-rank components)

    Raises
    ------
    ValueError
        If `data` is not two-dimensional.
    """
    LGR.info('Starting Go Decomposition')
    # The docstring promises array_like, so coerce lists etc. up front
    data = np.asarray(data)
    if data.ndim != 2:
        raise ValueError('Input data must be 2D (M x T); got {0}D '
                         'array'.format(data.ndim))
    n_vols = data.shape[1]

    # L is rebound (never modified in place) before any in-place update,
    # so the caller's array is left untouched.
    L = data
    S = np.zeros(L.shape)
    itr = 0
    random_state = np.random.RandomState(random_seed)
    while True:
        # Randomized range finder with power iterations
        Y2 = random_state.randn(n_vols, rank)
        for _ in range(power + 1):
            Y1 = np.dot(L, Y2)
            Y2 = np.dot(L.T, Y1)
        Q, _ = qr(Y2)

        # Project onto the estimated subspace to get the low-rank part
        L_new = np.dot(np.dot(L, Q), Q.T)
        T = L - L_new + S
        L = L_new

        # Soft-threshold what the low-rank part missed; the rest is residual
        S = _wthresh(T, thresh)
        T -= S
        err = np.linalg.norm(T.ravel(), 2)
        if err < tol:
            if verbose:
                LGR.info('Successful convergence after %i iterations', itr+1)
            break
        elif itr >= max_iter:
            if verbose:
                LGR.warning('Model failed to converge after %i iterations',
                            itr+1)
            break
        # Feed the residual back into the low-rank estimate for the next pass
        L += T
        itr += 1

    # Is this even useful in soft GoDec? May be a display issue...
    G = data - L - S
    return L, S, G


def _tedgodec(data, wavelet=False, rank=2, power=2, tol=1e-3,
              thresh=10, max_iter=500, norm_mode='vn', random_seed=0,
              verbose=True):
    """
    Perform TE-dependent Go Decomposition

    Normalizes each row of `data`, runs `godec` (optionally on transformed
    data), and returns the outputs in the original scale.

    Parameters
    ----------
    data : (M x T) array_like
        Data to decompose. Normalization is applied along the last axis.
    wavelet : :obj:`bool`, optional
        If True, the normalized data are transformed with `dwtmat` before
        decomposition, `thresh` is multiplied by the standard deviation of
        the transformed data, and all outputs are transformed back with
        `idwtmat`. Default: False
    rank, power, tol, max_iter, random_seed, verbose
        Passed through to `godec`.
    thresh : :obj:`float`, optional
        Threshold passed to `godec` (unchanged when `wavelet` is False).
        Default: 10
    norm_mode : {'vn', 'dm', None}, optional
        'dm' removes each row's mean; 'vn' removes the mean and divides by
        the row's standard deviation; any other value leaves the data
        as is. Default: 'vn'

    Returns
    -------
    L, S, G : (M x T) array_like
        Low-rank, sparse, and residual components from `godec`. `L` has
        the row means added back ('dm' and 'vn'); in 'vn' mode all three
        are rescaled by the row standard deviations.
    """
    data = np.asarray(data)

    if norm_mode == 'dm':
        # Demean. Keep the mean as a column vector so it broadcasts
        # against (M x T) arrays when it is added back below.
        data_mean = data.mean(-1)[:, np.newaxis]
        data_norm = data - data_mean
    elif norm_mode == 'vn':
        # Variance normalize
        data_mean = data.mean(-1)[:, np.newaxis]
        data_std = data.std(-1)[:, np.newaxis]
        # A constant row has zero std; dividing by it would create NaNs that
        # spread to every row through the matrix products in godec.
        data_std = np.where(data_std == 0, 1, data_std)
        data_norm = (data - data_mean) / data_std
    else:
        data_norm = data

    # GoDec
    if wavelet:
        data_wt, cal = dwtmat(data_norm)
        L, S, G = godec(data_wt, thresh=data_wt.std()*thresh, rank=rank,
                        power=power, tol=tol, max_iter=max_iter,
                        random_seed=random_seed, verbose=verbose)
        L = idwtmat(L, cal)
        S = idwtmat(S, cal)
        G = idwtmat(G, cal)
    else:
        L, S, G = godec(data_norm, thresh=thresh, rank=rank,
                        power=power, tol=tol, max_iter=max_iter,
                        random_seed=random_seed, verbose=verbose)

    if norm_mode == 'dm':
        # Remean
        L = L + data_mean
    elif norm_mode == 'vn':
        # Undo variance normalization
        L = (L * data_std) + data_mean
        S = S * data_std
        G = G * data_std

    return L, S, G


def tedgodec(optcom_ts, mmix, mask, acc, ign, ref_img, ranks=(2,),
             wavelet=False, thresh=10, norm_mode='vn', power=2, out_dir='.'):
    """
    Run Go Decomposition on multi-echo denoised data and save the results.

    The denoised time series is rebuilt from the accepted components plus
    the residual left after fitting all non-ignored components. It is then
    decomposed once per requested rank, and the low-rank, sparse, and
    residual outputs are written to `out_dir`.

    Parameters
    ----------
    optcom_ts : (S x T) array_like
        Optimally combined time series data
    mmix : (T x C) array_like
        Mixing matrix for converting input data to component space, where `C`
        is components and `T` is the same as in `optcom_ts`
    mask : (S,) array_like
        Boolean mask array
    acc : :obj:`list`
        Indices of accepted (BOLD) components in `mmix`
    ign : :obj:`list`
        Indices of all ignored components in `mmix`
    ref_img : :obj:`str` or img_like
        Reference image to dictate how outputs are saved to disk
    ranks : list of int
        Ranks of low-rank components to run
    wavelet : :obj:`bool`, optional
        Whether to run the decomposition on wavelet-transformed data.
        Default: False
    thresh : :obj:`float`, optional
        Threshold passed to the decomposition. Default: 10
    norm_mode : {'vn', 'dm', None}
        Normalization applied before decomposition. Default: 'vn'
    power : :obj:`int`, optional
        Power parameter passed to the decomposition. Default: 2
    out_dir : :obj:`str`, optional
        Output directory in which to save output files. Default: '.'

    Notes
    -----
    For each rank, three files are written to `out_dir`, where ``<suffix>``
    encodes the wavelet flag, normalization mode, rank, power and threshold:
    ``lowrank_<suffix>.nii``, ``sparse_<suffix>.nii`` and
    ``noise_<suffix>.nii``.
    """
    # Construct denoised data from optcom, mmix, acc, and ign
    optcom_masked = optcom_ts[mask, :]
    optcom_mu = optcom_masked.mean(axis=-1)[:, np.newaxis]
    optcom_std = optcom_masked.std(axis=-1)[:, np.newaxis]
    # Constant voxels have zero std; avoid dividing by zero (NaNs would
    # otherwise contaminate the fit and the decomposition).
    optcom_std = np.where(optcom_std == 0, 1, optcom_std)
    data_norm = (optcom_masked - optcom_mu) / optcom_std
    cbetas = lstsq(mmix, data_norm.T, rcond=None)[0].T
    # Components are the columns of mmix (rows are time points), matching
    # the lstsq call above and the column indexing below.
    all_comps = np.arange(mmix.shape[1])
    not_ign = sorted(np.setdiff1d(all_comps, ign))
    resid = data_norm - np.dot(cbetas[:, not_ign], mmix[:, not_ign].T)
    bold_ts = np.dot(cbetas[:, acc], mmix[:, acc].T)
    medn_ts = optcom_mu + ((bold_ts + resid) * optcom_std)

    # The filename prefix does not depend on rank, so build it once
    if norm_mode is None:
        name_norm_mode = ''
    else:
        name_norm_mode = 'n{0}'.format(norm_mode)

    if wavelet:
        name_norm_mode = 'w{0}'.format(name_norm_mode)

    for rank in ranks:
        # Forward `wavelet` so the computation matches the 'w' filename tag
        L, S, G = _tedgodec(medn_ts, wavelet=wavelet, rank=rank, power=power,
                            thresh=thresh, max_iter=500, norm_mode=norm_mode)

        suffix = '{0}r{1}p{2}t{3}'.format(name_norm_mode, rank, power, thresh)
        utils.filewrite(utils.unmask(L, mask),
                        op.join(out_dir, 'lowrank_{0}.nii'.format(suffix)),
                        ref_img)
        utils.filewrite(utils.unmask(S, mask),
                        op.join(out_dir, 'sparse_{0}.nii'.format(suffix)),
                        ref_img)
        utils.filewrite(utils.unmask(G, mask),
                        op.join(out_dir, 'noise_{0}.nii'.format(suffix)),
                        ref_img)
29 changes: 23 additions & 6 deletions tedana/model/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Fit models.
"""
import logging
import os.path as op

import nilearn.image as niimg
from nilearn._utils import check_niimg
Expand Down Expand Up @@ -371,12 +372,12 @@ def get_coeffs(data, X, mask=None, add_const=False):
return betas


def gscontrol_raw(catd, optcom, n_echos, ref_img, dtrank=4):
def gscontrol_raw(catd, optcom, n_echos, ref_img, out_dir='.', dtrank=4):
"""
Removes global signal from individual echo `catd` and `optcom` time series

This function uses the spatial global signal estimation approach to
to removal global signal out of individual echo time series datasets. The
to remove global signal out of individual echo time series datasets. The
spatial global signal is estimated from the optimally combined data after
detrending with a Legendre polynomial basis of `order = 0` and
`degree = dtrank`.
Expand All @@ -391,6 +392,8 @@ def gscontrol_raw(catd, optcom, n_echos, ref_img, dtrank=4):
Number of echos in data. Should be the same as `E` dimension of `catd`
ref_img : :obj:`str` or img_like
Reference image to dictate how outputs are saved to disk
out_dir : :obj:`str`
Output directory in which to save output files
dtrank : :obj:`int`, optional
Specifies degree of Legendre polynomial basis function for estimating
spatial global signal. Default: 4
Expand All @@ -401,6 +404,19 @@ def gscontrol_raw(catd, optcom, n_echos, ref_img, dtrank=4):
Input `catd` with global signal removed from time series
dm_optcom : (S x T) array_like
Input `optcom` with global signal removed from time series

Notes
-----
This function writes out several files. Files are listed below:

====================== =================================================
Filename Content
====================== =================================================
T1gs.nii Spatial global signal
tsoc_orig.nii Optimally combined data
tsoc_nogs.nii Optimally combined data with global signal
regressed out
====================== =================================================
"""
LGR.info('Applying amplitude-based T1 equilibration correction')
if catd.shape[0] != optcom.shape[0]:
Expand Down Expand Up @@ -429,23 +445,24 @@ def gscontrol_raw(catd, optcom, n_echos, ref_img, dtrank=4):
detr = dat - np.dot(sol.T, Lmix.T)[0]
sphis = (detr).min(axis=1)
sphis -= sphis.mean()
utils.filewrite(utils.unmask(sphis, Gmask), 'T1gs', ref_img)
utils.filewrite(utils.unmask(sphis, Gmask),
op.join(out_dir, 'T1gs.nii'), ref_img)

# find time course ofc the spatial global signal
# make basis with the Legendre basis
glsig = np.linalg.lstsq(np.atleast_2d(sphis).T, dat, rcond=None)[0]
glsig = stats.zscore(glsig, axis=None)
np.savetxt('glsig.1D', glsig)
np.savetxt(op.join(out_dir, 'glsig.1D'), glsig)
glbase = np.hstack([Lmix, glsig.T])

# Project global signal out of optimally combined data
sol = np.linalg.lstsq(np.atleast_2d(glbase), dat.T, rcond=None)[0]
tsoc_nogs = dat - np.dot(np.atleast_2d(sol[dtrank]).T,
np.atleast_2d(glbase.T[dtrank])) + Gmu[Gmask][:, np.newaxis]

utils.filewrite(optcom, 'tsoc_orig', ref_img)
utils.filewrite(optcom, op.join(out_dir, 'tsoc_orig.nii'), ref_img)
dm_optcom = utils.unmask(tsoc_nogs, Gmask)
utils.filewrite(dm_optcom, 'tsoc_nogs', ref_img)
utils.filewrite(dm_optcom, op.join(out_dir, 'tsoc_nogs.nii'), ref_img)

# Project glbase out of each echo
dm_catd = catd.copy() # don't overwrite catd
Expand Down
14 changes: 10 additions & 4 deletions tedana/selection/select_comps.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Functions to identify TE-dependent and TE-independent components.
"""
import os
import os.path as op
import json
import logging
import pickle
Expand All @@ -21,7 +22,8 @@


def selcomps(seldict, mmix, mask, ref_img, manacc, n_echos, t2s, s0, olevel=2,
oversion=99, filecsdata=True, savecsdiag=True, strict_mode=False):
oversion=99, filecsdata=True, savecsdiag=True, strict_mode=False,
out_dir='.'):
"""
Labels ICA components to keep or remove from denoised data

Expand Down Expand Up @@ -60,6 +62,8 @@ def selcomps(seldict, mmix, mask, ref_img, manacc, n_echos, t2s, s0, olevel=2,
Default: True
strict_mode: :obj:`bool`, optional
Default: False
out_dir : :obj:`str`
Output directory in which to save output files

Returns
-------
Expand Down Expand Up @@ -776,7 +780,9 @@ def selcomps(seldict, mmix, mask, ref_img, manacc, n_echos, t2s, s0, olevel=2,
group0_res = np.intersect1d(KRguess, group0)
phys_var_zs.append((vvex - vvex[group0_res].mean()) / vvex[group0_res].std())
veinBout = utils.unmask(veinmaskB, mask)
utils.filewrite(veinBout.astype(float), 'veins_l%i' % t2sl_i, ref_img)
utils.filewrite(veinBout.astype(float),
op.join(out_dir, 'veins_l{}.nii'.format(t2sl_i)),
ref_img)

# Mask to sample veins
phys_var_z = np.array(phys_var_zs).max(0)
Expand Down Expand Up @@ -855,10 +861,10 @@ def selcomps(seldict, mmix, mask, ref_img, manacc, n_echos, t2s, s0, olevel=2,
list(field_art), list(phys_art),
list(misc_art), list(acc_comps), list(ign)]

with open('csstepdata.json', 'w') as ofh:
with open(op.join(out_dir, 'csstepdata.json'), 'w') as ofh:
json.dump(dict(zip(diagstep_keys, diagstep_vals)), ofh,
indent=4, sort_keys=True, default=str)
allfz = np.array([Tz, Vz, Ktz, KRr, cnz, Rz, mmix_kurt, fdist_z])
np.savetxt('csdata.txt', allfz)
np.savetxt(op.join(out_dir, 'csdata.txt'), allfz)

return list(sorted(acc_comps)), list(sorted(rej)), list(sorted(midk)), list(sorted(ign))
Loading