Skip to content

Commit

Permalink
Add robustica method
Browse files Browse the repository at this point in the history
  • Loading branch information
BahmanTahayori authored and Lestropie committed Aug 28, 2023
1 parent 966a262 commit f4eaa3e
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 14 deletions.
50 changes: 49 additions & 1 deletion tedana/decomposition/ica.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,18 @@
import logging
import warnings

import sys

import numpy as np
from scipy import stats
from sklearn.decomposition import FastICA
from robustica import RobustICA ####BTBTBT

LGR = logging.getLogger("GENERAL")
RepLGR = logging.getLogger("REPORT")


def tedica(data, n_components, fixed_seed, maxit=500, maxrestart=10):
def tedica(data, n_components, fixed_seed, ica_method="robustica", n_robust_runs=30, maxit=500, maxrestart=10): ####BTBTBTB
"""
Perform ICA on `data` and returns mixing matrix
Expand Down Expand Up @@ -50,6 +53,51 @@ def tedica(data, n_components, fixed_seed, maxit=500, maxrestart=10):
"decompose the dimensionally reduced dataset."
)

if ica_method=='robustica':
mmix, Iq = r_ica(data, n_components, n_robust_runs, maxit)
fixed_seed=-99999
elif ica_method=='fastica':
mmix, fixed_seed=f_ica(data, n_components, fixed_seed, maxit=500, maxrestart=10)
Iq = 0
else:
LGR.warning("The selected ICA method is invalid!")
sys.exit()




return mmix, fixed_seed


def r_ica(data, n_components, n_robust_runs, max_it): ####BTBTBTB:

if n_robust_runs>100:
LGR.warning("The selected n_robust_runs is a very big number!")


RepLGR.info(
"RobustICA package was used for ICA decomposition \\citep{Anglada2022}."
)
rica0 = RobustICA(n_components=n_components, robust_runs=n_robust_runs, whiten='arbitrary-variance',max_iter= max_it,
robust_dimreduce=False, fun='logcosh')
S0, mmix = rica0.fit_transform(data)

q0 = rica0.evaluate_clustering(rica0.S_all, rica0.clustering.labels_, rica0.signs_, rica0.orientation_)


Iq0 = np.array(np.mean(q0.iq))


mmix = stats.zscore(mmix, axis=0)

LGR.info(
"RobustICA with {0} robust runs was used \n"
"The mean index quality is {1}".format(n_robust_runs, Iq0)
)
return mmix, Iq0


def f_ica(data, n_components, fixed_seed, maxit, maxrestart):
if fixed_seed == -1:
fixed_seed = np.random.randint(low=1, high=1000)

Expand Down
60 changes: 47 additions & 13 deletions tedana/workflows/tedana.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def _get_parser():
"in which case the specificed number of components will be "
"selected."
),
choices=["mdl", "kic", "aic"],
default="aic",
)
optional.add_argument(
Expand All @@ -164,19 +165,46 @@ def _get_parser():
),
default="kundu",
)
optional.add_argument(#####BTBTBT
"--ica_method",
dest="ica_method",
help=(
"The applied ICA method. If set to fastica the FastICA "
"from sklearn library will be run once with the seed value. "
"robustica will run FastICA n_robust_runs times and and uses "
"clustering methods to overcome the randomness of the FastICA "
"algorithm. If set to robustica the seed value will be ignored."
"If set to fastica n_robust_runs will not be effective."
),
choices=["robustica", "fastica"],
default="robustica",
)
optional.add_argument(
"--seed",
dest="fixed_seed",
metavar="INT",
type=int,
help=(
help=( ##BTBTBT
"Value used for random initialization of ICA "
"algorithm. Set to an integer value for "
"reproducible ICA results. Set to -1 for "
"algorithm when ica_mthods is set to fastica. Set to an integer value for "
"reproducible ICA results with fastica. Set to -1 for "
"varying results across ICA calls. "
),
default=42,
)
optional.add_argument(#####BTBTBT
"--n_robust_runs",
dest="n_robust_runs",
type=int,
help=(
"The number of times robustica will run."
"This is only effective when ica_mthods is "
"set to robustica."

),
##choices=range(2,100),
default=30,
)
optional.add_argument(
"--maxit",
dest="maxit",
Expand Down Expand Up @@ -323,6 +351,8 @@ def tedana_workflow(
fittype="loglin",
combmode="t2s",
tree="kundu",
ica_method="robustica", ########BTBTAdded
n_robust_runs=30,
tedpca="aic",
fixed_seed=42,
maxit=500,
Expand Down Expand Up @@ -385,9 +415,7 @@ def tedana_workflow(
tedpca : {'mdl', 'aic', 'kic', 'kundu', 'kundu-stabilize', float, int}, optional
Method with which to select components in TEDPCA.
If a float is provided, then it is assumed to represent percentage of variance
explained (0-1) to retain from PCA. If an int is provided, it will output
a fixed number of components defined by the integer between 1 and the
number of time points.
explained (0-1) to retain from PCA.
Default is 'aic'.
fixed_seed : :obj:`int`, optional
Value passed to ``mdp.numx_rand.seed()``.
Expand Down Expand Up @@ -639,11 +667,12 @@ def tedana_workflow(
# Perform ICA, calculate metrics, and apply decision tree
# Restart when ICA fails to converge or too few BOLD components found
keep_restarting = True

n_restarts = 0
seed = fixed_seed
while keep_restarting:
mmix, seed = decomposition.tedica(
dd, n_components, seed, maxit, maxrestart=(maxrestart - n_restarts)
dd, n_components, seed, ica_method, n_robust_runs, maxit, maxrestart=(maxrestart - n_restarts)
)
seed += 1
n_restarts = seed - fixed_seed
Expand Down Expand Up @@ -677,13 +706,17 @@ def tedana_workflow(
)
ica_selector = selection.automatic_selection(comptable, n_echos, n_vols, tree=tree)
n_likely_bold_comps = ica_selector.n_likely_bold_comps
if (n_restarts < maxrestart) and (n_likely_bold_comps == 0):
LGR.warning("No BOLD components found. Re-attempting ICA.")
elif n_likely_bold_comps == 0:
LGR.warning("No BOLD components found, but maximum number of restarts reached.")
keep_restarting = False
else:

if ica_method=='robustica': #########BTBTBT
keep_restarting = False
else:
if (n_restarts < maxrestart) and (n_likely_bold_comps == 0):
LGR.warning("No BOLD components found. Re-attempting ICA.")
elif n_likely_bold_comps == 0:
LGR.warning("No BOLD components found, but maximum number of restarts reached.")
keep_restarting = False
else:
keep_restarting = False

# If we're going to restart, temporarily allow force overwrite
if keep_restarting:
Expand Down Expand Up @@ -893,3 +926,4 @@ def _main(argv=None):

if __name__ == "__main__":
_main()

0 comments on commit f4eaa3e

Please sign in to comment.