Skip to content

Commit

Permalink
Print optimal number of maPCA components and plot optimization curves (
Browse files Browse the repository at this point in the history
…#839)

* Retrieve optimal number of PCA components and plot optimization curves

* Updated to retrieve variance explained criteria from maPCA

* Added prints for PCA info and tested the plot gets generated correctly

* Trigger tests

* Increase minimum maPCA version

* Updated expected results for three-echo test

* Generate variance explained plot and save cross component metrics

* Update __init__.py

* Save maPCA results into dictionary

* Update minimum maPCA version required

* Fixes numpy int issues

* Pins later mapca version

* Fix bad merge

* Fix explained variance figure

* Removed breakpoint

Co-authored-by: Joshua Teves <joshua.teves@nih.gov>
  • Loading branch information
eurunuela and Joshua Teves committed May 17, 2022
1 parent e44d488 commit fc61dec
Show file tree
Hide file tree
Showing 7 changed files with 243 additions and 7 deletions.
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ classifiers =
python_requires = >= 3.6
install_requires =
bokeh<2.3.0
mapca~=0.0.1
mapca>=0.0.3
matplotlib
nibabel>=2.5.1
nilearn>=0.7
Expand Down
90 changes: 87 additions & 3 deletions tedana/decomposition/pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@

import numpy as np
import pandas as pd
from mapca import ma_pca
from mapca import MovingAveragePCA
from scipy import stats
from sklearn.decomposition import PCA

from tedana import io, metrics, utils
from tedana.reporting import pca_results as plot_pca_results
from tedana.selection import kundu_tedpca
from tedana.stats import computefeats2

Expand Down Expand Up @@ -244,9 +245,92 @@ def tedpca(
if algorithm in ["mdl", "aic", "kic"]:
data_img = io.new_nii_like(io_generator.reference_img, utils.unmask(data, mask))
mask_img = io.new_nii_like(io_generator.reference_img, mask.astype(int))
voxel_comp_weights, varex, varex_norm, comp_ts = ma_pca(
data_img, mask_img, algorithm, normalize=True
ma_pca = MovingAveragePCA(criterion=algorithm, normalize=True)
_ = ma_pca.fit_transform(data_img, mask_img)

# Extract results from maPCA
voxel_comp_weights = ma_pca.u_
varex = ma_pca.explained_variance_
varex_norm = ma_pca.explained_variance_ratio_
comp_ts = ma_pca.components_.T
aic = ma_pca.aic_
kic = ma_pca.kic_
mdl = ma_pca.mdl_
varex_90 = ma_pca.varexp_90_
varex_95 = ma_pca.varexp_95_
all_comps = ma_pca.all_

# Extract number of components and variance explained for logging and plotting
n_aic = aic["n_components"]
aic_varexp = np.round(aic["explained_variance_total"], 3)
n_kic = kic["n_components"]
kic_varexp = np.round(kic["explained_variance_total"], 3)
n_mdl = mdl["n_components"]
mdl_varexp = np.round(mdl["explained_variance_total"], 3)
n_varex_90 = varex_90["n_components"]
varex_90_varexp = np.round(varex_90["explained_variance_total"], 3)
n_varex_95 = varex_95["n_components"]
varex_95_varexp = np.round(varex_95["explained_variance_total"], 3)
all_varex = np.round(all_comps["explained_variance_total"], 3)

# Print out the results
LGR.info("Optimal number of components based on different criteria:")
LGR.info(
f"AIC: {n_aic} | KIC: {n_kic} | MDL: {n_mdl} | 90% varexp: {n_varex_90} "
f"| 95% varexp: {n_varex_95}"
)

LGR.info("Explained variance based on different criteria:")
LGR.info(
f"AIC: {aic_varexp}% | KIC: {kic_varexp}% | MDL: {mdl_varexp}% | "
f"90% varexp: {varex_90_varexp}% | 95% varexp: {varex_95_varexp}%"
)

pca_optimization_curves = np.array([aic["value"], kic["value"], mdl["value"]])
pca_criteria_components = np.array(
[
n_aic,
n_kic,
n_mdl,
n_varex_90,
n_varex_95,
]
)

# Plot maPCA optimization curves
LGR.info("Plotting maPCA optimization curves")
plot_pca_results(pca_optimization_curves, pca_criteria_components, all_varex, io_generator)

# Save maPCA results into a dictionary
mapca_results = {
"aic": {
"n_components": n_aic,
"explained_variance_total": aic_varexp,
"curve": aic["value"],
},
"kic": {
"n_components": n_kic,
"explained_variance_total": kic_varexp,
"curve": kic["value"],
},
"mdl": {
"n_components": n_mdl,
"explained_variance_total": mdl_varexp,
"curve": mdl["value"],
},
"varex_90": {
"n_components": n_varex_90,
"explained_variance_total": varex_90_varexp,
},
"varex_95": {
"n_components": n_varex_95,
"explained_variance_total": varex_95_varexp,
},
}

# Save dictionary
io_generator.save_file(mapca_results, "PCA cross component metrics json")

elif isinstance(algorithm, Number):
ppca = PCA(copy=False, n_components=algorithm, svd_solver="full")
ppca.fit(data_z)
Expand Down
4 changes: 4 additions & 0 deletions tedana/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,7 +719,11 @@ def prep_data_for_json(d) -> dict:
# One of the values in the dict is the problem, need to recurse
v = prep_data_for_json(v)
elif isinstance(v, np.ndarray):
if v.dtype == np.int64 or v.dtype == np.uint64:
v = int(v)
v = v.tolist()
elif isinstance(v, np.int64) or isinstance(v, np.uint64):
v = int(v)
# NOTE: add more special cases for type conversions above this
# comment line as an elif block
d[k] = v
Expand Down
4 changes: 2 additions & 2 deletions tedana/reporting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@
"""

from .html_report import generate_report
from .static_figures import comp_figures
from .static_figures import comp_figures, pca_results

__all__ = ["generate_report", "comp_figures"]
__all__ = ["generate_report", "comp_figures", "pca_results"]
141 changes: 141 additions & 0 deletions tedana/reporting/static_figures.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,3 +288,144 @@ def comp_figures(ts, mask, comptable, mmix, io_generator, png_cmap):
compplot_name = os.path.join(io_generator.out_dir, "figures", plot_name)
plt.savefig(compplot_name)
plt.close()


def pca_results(criteria, n_components, all_varex, io_generator):
"""
Plot the PCA optimization curve for each criteria, and the variance explained curve.
Parameters
----------
criteria : array-like
AIC, KIC, and MDL optimization values for increasing number of components.
n_components : array-like
Number of optimal components given by each criteria.
io_generator : object
An object containing all the information needed to generate the output.
"""

# Plot the PCA optimization curve for each criteria
plt.figure(figsize=(10, 9))
plt.title("PCA Criteria")
plt.xlabel("PCA components")
plt.ylabel("Arbitrary Units")

# AIC curve
plt.plot(criteria[0, :], color="tab:blue", label="AIC")
# KIC curve
plt.plot(criteria[1, :], color="tab:orange", label="KIC")
# MDL curve
plt.plot(criteria[2, :], color="tab:green", label="MDL")

# Vertical line depicting the optimal number of components given by AIC
plt.vlines(
n_components[0],
ymin=np.min(criteria),
ymax=np.max(criteria),
color="tab:blue",
linestyles="dashed",
)
# Vertical line depicting the optimal number of components given by KIC
plt.vlines(
n_components[1],
ymin=np.min(criteria),
ymax=np.max(criteria),
color="tab:orange",
linestyles="dashed",
)
# Vertical line depicting the optimal number of components given by MDL
plt.vlines(
n_components[2],
ymin=np.min(criteria),
ymax=np.max(criteria),
color="tab:green",
linestyles="dashed",
)
# Vertical line depicting the optimal number of components for 90% variance explained
plt.vlines(
n_components[3],
ymin=np.min(criteria),
ymax=np.max(criteria),
color="tab:red",
linestyles="dashed",
label="90% varexp",
)
# Vertical line depicting the optimal number of components for 95% variance explained
plt.vlines(
n_components[4],
ymin=np.min(criteria),
ymax=np.max(criteria),
color="tab:purple",
linestyles="dashed",
label="95% varexp",
)

plt.legend()

#  Save the plot
plot_name = "pca_criteria.png"
pca_criteria_name = os.path.join(io_generator.out_dir, "figures", plot_name)
plt.savefig(pca_criteria_name)
plt.close()

# Plot the variance explained curve
plt.figure(figsize=(10, 9))
plt.title("Variance Explained")
plt.xlabel("PCA components")
plt.ylabel("Variance Explained")

plt.plot(all_varex, color="black", label="Variance Explained")

# Vertical line depicting the optimal number of components given by AIC
plt.vlines(
n_components[0],
ymin=0,
ymax=1,
color="tab:blue",
linestyles="dashed",
label="AIC",
)
# Vertical line depicting the optimal number of components given by KIC
plt.vlines(
n_components[1],
ymin=0,
ymax=1,
color="tab:orange",
linestyles="dashed",
label="KIC",
)
# Vertical line depicting the optimal number of components given by MDL
plt.vlines(
n_components[2],
ymin=0,
ymax=1,
color="tab:green",
linestyles="dashed",
label="MDL",
)
# Vertical line depicting the optimal number of components for 90% variance explained
plt.vlines(
n_components[3],
ymin=0,
ymax=1,
color="tab:red",
linestyles="dashed",
label="90% varexp",
)
# Vertical line depicting the optimal number of components for 95% variance explained
plt.vlines(
n_components[4],
ymin=0,
ymax=1,
color="tab:purple",
linestyles="dashed",
label="95% varexp",
)

plt.legend()

#  Save the plot
plot_name = "pca_variance_explained.png"
pca_variance_explained_name = os.path.join(io_generator.out_dir, "figures", plot_name)
plt.savefig(pca_variance_explained_name)
plt.close()
4 changes: 4 additions & 0 deletions tedana/resources/config/outputs.json
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,10 @@
"orig": "pca_metrics",
"bidsv1.5.0": "desc-PCA_metrics"
},
"PCA cross component metrics json": {
"orig": "pca_cross_component_metrics",
"bidsv1.5.0": "desc-PCA_cross_component_metrics"
},
"ICA decomposition json": {
"orig": "ica_decomposition",
"bidsv1.5.0": "desc-ICA_decomposition"
Expand Down
5 changes: 4 additions & 1 deletion tedana/tests/data/cornell_three_echo_outputs.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ desc-tedana_metrics.json
desc-tedana_metrics.tsv
desc-ICA_mixing.tsv
desc-ICA_stat-z_components.nii.gz
desc-PCA_cross_component_metrics.json
desc-PCA_decomposition.json
desc-PCA_metrics.json
desc-PCA_metrics.tsv
Expand Down Expand Up @@ -93,5 +94,7 @@ figures/comp_065.png
figures/comp_066.png
figures/comp_067.png
figures/comp_068.png
figures/pca_criteria.png
figures/pca_variance_explained.png
report.txt
tedana_report.html
tedana_report.html

0 comments on commit fc61dec

Please sign in to comment.