Skip to content

Commit

Permalink
[ENH, FIX] PCA variance enhancements and consistency improvements (#877)
Browse files Browse the repository at this point in the history
* PCA variance enhancements

* style fix: one line was too long

* Update tedana/decomposition/pca.py

typo

Co-authored-by: Joshua Teves <jbtevespro@gmail.com>

* Update tedana/workflows/tedana.py

Co-authored-by: Joshua Teves <jbtevespro@gmail.com>

Co-authored-by: Joshua Teves <jbtevespro@gmail.com>
  • Loading branch information
handwerkerd and jbteves authored May 13, 2022
1 parent d4406e4 commit e44d488
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 21 deletions.
36 changes: 24 additions & 12 deletions tedana/decomposition/pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ def low_mem_pca(data):
Component weight map for each component.
s : (C,) array_like
Variance explained for each component.
varex_norm : array-like, shape (n_components,)
Explained variance ratio.
v : (C x T) array_like
Component timeseries.
"""
Expand All @@ -44,7 +46,8 @@ def low_mem_pca(data):
v = ppca.components_.T
s = ppca.explained_variance_
u = np.dot(np.dot(data, v), np.diag(1.0 / s))
return u, s, v
varex_norm = ppca.explained_variance_ratio_
return u, s, varex_norm, v


def tedpca(
Expand Down Expand Up @@ -96,6 +99,8 @@ def tedpca(
(see Li et al., 2007).
If a float is provided, then it is assumed to represent percentage of variance
explained (0-1) to retain from PCA.
If an int is provided, then it is assumed to be the number of components
to select
Default is 'aic'.
kdaw : :obj:`float`, optional
Dimensionality augmentation weight for Kappa calculations. Must be a
Expand Down Expand Up @@ -201,10 +206,13 @@ def tedpca(
"16187-16192."
)
elif isinstance(algorithm, Number):
alg_str = (
"in which the number of components was determined based on a "
"variance explained threshold"
)
if isinstance(algorithm, float):
alg_str = (
"in which the number of components was determined based on a "
"variance explained threshold"
)
else:
alg_str = "in which the number of components is pre-defined"
else:
alg_str = (
"based on the PCA component estimation with a Moving Average"
Expand Down Expand Up @@ -245,17 +253,16 @@ def tedpca(
comp_ts = ppca.components_.T
varex = ppca.explained_variance_
voxel_comp_weights = np.dot(np.dot(data_z, comp_ts), np.diag(1.0 / varex))
varex_norm = varex / varex.sum()
varex_norm = ppca.explained_variance_ratio_
elif low_mem:
voxel_comp_weights, varex, comp_ts = low_mem_pca(data_z)
varex_norm = varex / varex.sum()
voxel_comp_weights, varex, varex_norm, comp_ts = low_mem_pca(data_z)
else:
ppca = PCA(copy=False, n_components=(n_vols - 1))
ppca.fit(data_z)
comp_ts = ppca.components_.T
varex = ppca.explained_variance_
voxel_comp_weights = np.dot(np.dot(data_z, comp_ts), np.diag(1.0 / varex))
varex_norm = varex / varex.sum()
varex_norm = ppca.explained_variance_ratio_

# Compute Kappa and Rho for PCA comps
required_metrics = [
Expand Down Expand Up @@ -311,10 +318,15 @@ def tedpca(
stabilize=True,
)
else:
alg_str = "variance explained-based" if isinstance(algorithm, Number) else algorithm
if isinstance(algorithm, float):
alg_str = "variance explained-based"
elif isinstance(algorithm, int):
alg_str = "a fixed number of components and no"
else:
alg_str = algorithm
LGR.info(
"Selected {0} components with {1} dimensionality "
"detection".format(comptable.shape[0], alg_str)
f"Selected {comptable.shape[0]} components with {round(100*varex_norm.sum(),2)}% "
f"normalized variance explained using {alg_str} dimensionality detection"
)
comptable["classification"] = "accepted"
comptable["rationale"] = ""
Expand Down
4 changes: 4 additions & 0 deletions tedana/tests/test_workflows_parser_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ def test_check_tedpca_value():
with pytest.raises(ValueError):
check_tedpca_value(1.5, is_parser=False)

with pytest.raises(ValueError):
check_tedpca_value(-1, is_parser=False)

assert check_tedpca_value(0.95) == 0.95
assert check_tedpca_value("0.95") == 0.95
assert check_tedpca_value("mdl") == "mdl"
assert check_tedpca_value(52) == 52
19 changes: 14 additions & 5 deletions tedana/workflows/parser_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@


def check_tedpca_value(string, is_parser=True):
"""Check if argument is a float in range 0-1 or one of a list of strings."""
"""
Check if argument is a float in range (0,1),
an int greater than 1 or one of a list of strings.
"""
valid_options = ("mdl", "aic", "kic", "kundu", "kundu-stabilize")
if string in valid_options:
return string
Expand All @@ -15,12 +18,18 @@ def check_tedpca_value(string, is_parser=True):
try:
floatarg = float(string)
except ValueError:
msg = "Argument to tedpca must be a float or one of: {}".format(", ".join(valid_options))
msg = "Argument to tedpca must be a number or one of: {}".format(", ".join(valid_options))
raise error(msg)

if not (0 <= floatarg <= 1):
raise error("Float argument to tedpca must be between 0 and 1.")
return floatarg
if floatarg != int(floatarg):
if not (0 < floatarg < 1):
raise error("Float argument to tedpca must be between 0 and 1.")
return floatarg
else:
intarg = int(floatarg)
if floatarg < 1:
raise error("Int argument must be greater than 1")
return intarg


def is_valid_file(parser, arg):
Expand Down
11 changes: 7 additions & 4 deletions tedana/workflows/tedana.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,10 +146,12 @@ def _get_parser():
"process and are ordered from most to least aggressive. "
"Users may also provide a float from 0 to 1, "
"in which case components will be selected based on the "
"cumulative variance explained. "
"Default='mdl'."
"cumulative variance explained or an integer greater than 1"
"in which case the specificed number of components will be"
"selected."
"Default='aic'."
),
default="mdl",
default="aic",
)
optional.add_argument(
"--seed",
Expand Down Expand Up @@ -473,7 +475,8 @@ def tedana_workflow(
if not isinstance(gscontrol, list):
gscontrol = [gscontrol]

# Check value of tedpca *if* it is a float
# Check value of tedpca *if* it is a predefined string,
# a float on [0, 1] or an int >= 1
tedpca = check_tedpca_value(tedpca, is_parser=False)

LGR.info("Loading input data: {}".format([f for f in data]))
Expand Down

0 comments on commit e44d488

Please sign in to comment.