Skip to content

Commit

Permalink
Refactor metrics.dependence module (#1088)
Browse files Browse the repository at this point in the history
* Add type hints to metric functions.

* Use keyword arguments.

* Update tests.

* Update dependence.py

* Update collect.py

* Fix other stuff.
  • Loading branch information
tsalo committed Aug 7, 2024
1 parent 18a408e commit 811b9fd
Show file tree
Hide file tree
Showing 4 changed files with 231 additions and 83 deletions.
122 changes: 74 additions & 48 deletions tedana/metrics/collect.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,10 @@ def generate_metrics(
metric_maps = {}
if "map weight" in required_metrics:
LGR.info("Calculating weight maps")
metric_maps["map weight"] = dependence.calculate_weights(data_optcom, mixing)
metric_maps["map weight"] = dependence.calculate_weights(
data_optcom=data_optcom,
mixing=mixing,
)
signs = determine_signs(metric_maps["map weight"], axis=0)
comptable["optimal sign"] = signs
metric_maps["map weight"], mixing = flip_components(
Expand All @@ -157,31 +160,42 @@ def generate_metrics(

if "map optcom betas" in required_metrics:
LGR.info("Calculating parameter estimate maps for optimally combined data")
metric_maps["map optcom betas"] = dependence.calculate_betas(data_optcom, mixing)
metric_maps["map optcom betas"] = dependence.calculate_betas(
data=data_optcom,
mixing=mixing,
)
if io_generator.verbose:
metric_maps["map echo betas"] = dependence.calculate_betas(data_cat, mixing)
metric_maps["map echo betas"] = dependence.calculate_betas(
data=data_cat,
mixing=mixing,
)

if "map percent signal change" in required_metrics:
LGR.info("Calculating percent signal change maps")
# used in kundu v3.2 tree
metric_maps["map percent signal change"] = dependence.calculate_psc(
data_optcom, metric_maps["map optcom betas"]
data_optcom=data_optcom,
optcom_betas=metric_maps["map optcom betas"],
)

if "map Z" in required_metrics:
LGR.info("Calculating z-statistic maps")
metric_maps["map Z"] = dependence.calculate_z_maps(metric_maps["map weight"])
metric_maps["map Z"] = dependence.calculate_z_maps(weights=metric_maps["map weight"])

if io_generator.verbose:
io_generator.save_file(
utils.unmask(metric_maps["map Z"] ** 2, mask),
label + " component weights img",
f"{label} component weights img",
)

if ("map FT2" in required_metrics) or ("map FS0" in required_metrics):
LGR.info("Calculating F-statistic maps")
m_t2, m_s0, p_m_t2, p_m_s0 = dependence.calculate_f_maps(
data_cat, metric_maps["map Z"], mixing, adaptive_mask, tes
data_cat=data_cat,
z_maps=metric_maps["map Z"],
mixing=mixing,
adaptive_mask=adaptive_mask,
tes=tes,
)
metric_maps["map FT2"] = m_t2
metric_maps["map FS0"] = m_s0
Expand All @@ -191,58 +205,73 @@ def generate_metrics(
if io_generator.verbose:
io_generator.save_file(
utils.unmask(metric_maps["map FT2"], mask),
label + " component F-T2 img",
f"{label} component F-T2 img",
)
io_generator.save_file(
utils.unmask(metric_maps["map FS0"], mask),
label + " component F-S0 img",
f"{label} component F-S0 img",
)

if "map Z clusterized" in required_metrics:
LGR.info("Thresholding z-statistic maps")
z_thresh = 1.95
metric_maps["map Z clusterized"] = dependence.threshold_map(
metric_maps["map Z"], mask, ref_img, z_thresh
maps=metric_maps["map Z"],
mask=mask,
ref_img=ref_img,
threshold=z_thresh,
)

if "map FT2 clusterized" in required_metrics:
LGR.info("Calculating T2* F-statistic maps")
f_thresh, _, _ = getfbounds(len(tes))
metric_maps["map FT2 clusterized"] = dependence.threshold_map(
metric_maps["map FT2"], mask, ref_img, f_thresh
maps=metric_maps["map FT2"],
mask=mask,
ref_img=ref_img,
threshold=f_thresh,
)

if "map FS0 clusterized" in required_metrics:
LGR.info("Calculating S0 F-statistic maps")
f_thresh, _, _ = getfbounds(len(tes))
metric_maps["map FS0 clusterized"] = dependence.threshold_map(
metric_maps["map FS0"], mask, ref_img, f_thresh
maps=metric_maps["map FS0"],
mask=mask,
ref_img=ref_img,
threshold=f_thresh,
)

# Intermediate metrics
if "countsigFT2" in required_metrics:
LGR.info("Counting significant voxels in T2* F-statistic maps")
comptable["countsigFT2"] = dependence.compute_countsignal(
metric_maps["map FT2 clusterized"]
stat_cl_maps=metric_maps["map FT2 clusterized"],
)

if "countsigFS0" in required_metrics:
LGR.info("Counting significant voxels in S0 F-statistic maps")
comptable["countsigFS0"] = dependence.compute_countsignal(
metric_maps["map FS0 clusterized"]
stat_cl_maps=metric_maps["map FS0 clusterized"],
)

# Back to maps
if "map beta T2 clusterized" in required_metrics:
LGR.info("Thresholding optimal combination beta maps to match T2* F-statistic maps")
metric_maps["map beta T2 clusterized"] = dependence.threshold_to_match(
metric_maps["map optcom betas"], comptable["countsigFT2"], mask, ref_img
maps=metric_maps["map optcom betas"],
n_sig_voxels=comptable["countsigFT2"],
mask=mask,
ref_img=ref_img,
)

if "map beta S0 clusterized" in required_metrics:
LGR.info("Thresholding optimal combination beta maps to match S0 F-statistic maps")
metric_maps["map beta S0 clusterized"] = dependence.threshold_to_match(
metric_maps["map optcom betas"], comptable["countsigFS0"], mask, ref_img
maps=metric_maps["map optcom betas"],
n_sig_voxels=comptable["countsigFS0"],
mask=mask,
ref_img=ref_img,
)

# Dependence metrics
Expand All @@ -258,24 +287,23 @@ def generate_metrics(
if "variance explained" in required_metrics:
LGR.info("Calculating variance explained")
comptable["variance explained"] = dependence.calculate_varex(
metric_maps["map optcom betas"]
optcom_betas=metric_maps["map optcom betas"],
)

if "normalized variance explained" in required_metrics:
LGR.info("Calculating normalized variance explained")
comptable["normalized variance explained"] = dependence.calculate_varex_norm(
metric_maps["map weight"]
weights=metric_maps["map weight"],
)

# Spatial metrics
if "dice_FT2" in required_metrics:
LGR.info(
"Calculating DSI between thresholded T2* F-statistic and "
"optimal combination beta maps"
"Calculating DSI between thresholded T2* F-statistic and optimal combination beta maps"
)
comptable["dice_FT2"] = dependence.compute_dice(
metric_maps["map beta T2 clusterized"],
metric_maps["map FT2 clusterized"],
clmaps1=metric_maps["map beta T2 clusterized"],
clmaps2=metric_maps["map FT2 clusterized"],
axis=0,
)

Expand All @@ -285,20 +313,18 @@ def generate_metrics(
"optimal combination beta maps"
)
comptable["dice_FS0"] = dependence.compute_dice(
metric_maps["map beta S0 clusterized"],
metric_maps["map FS0 clusterized"],
clmaps1=metric_maps["map beta S0 clusterized"],
clmaps2=metric_maps["map FS0 clusterized"],
axis=0,
)

if "signal-noise_t" in required_metrics:
LGR.info("Calculating signal-noise t-statistics")
RepLGR.info(
"A t-test was performed between the distributions of T2*-model "
"F-statistics associated with clusters (i.e., signal) and "
"non-cluster voxels (i.e., noise) to generate a t-statistic "
"(metric signal-noise_z) and p-value (metric signal-noise_p) "
"measuring relative association of the component to signal "
"over noise."
"A t-test was performed between the distributions of T2*-model F-statistics "
"associated with clusters (i.e., signal) and non-cluster voxels (i.e., noise) to "
"generate a t-statistic (metric signal-noise_z) and p-value (metric signal-noise_p) "
"measuring relative association of the component to signal over noise."
)
(
comptable["signal-noise_t"],
Expand All @@ -312,20 +338,18 @@ def generate_metrics(
if "signal-noise_z" in required_metrics:
LGR.info("Calculating signal-noise z-statistics")
RepLGR.info(
"A t-test was performed between the distributions of T2*-model "
"F-statistics associated with clusters (i.e., signal) and "
"non-cluster voxels (i.e., noise) to generate a z-statistic "
"(metric signal-noise_z) and p-value (metric signal-noise_p) "
"measuring relative association of the component to signal "
"over noise."
"A t-test was performed between the distributions of T2*-model F-statistics "
"associated with clusters (i.e., signal) and non-cluster voxels (i.e., noise) to "
"generate a z-statistic (metric signal-noise_z) and p-value (metric signal-noise_p) "
"measuring relative association of the component to signal over noise."
)
(
comptable["signal-noise_z"],
comptable["signal-noise_p"],
) = dependence.compute_signal_minus_noise_z(
Z_maps=metric_maps["map Z"],
Z_clmaps=metric_maps["map Z clusterized"],
F_T2_maps=metric_maps["map FT2"],
z_maps=metric_maps["map Z"],
z_clmaps=metric_maps["map Z clusterized"],
f_t2_maps=metric_maps["map FT2"],
)

if "countnoise" in required_metrics:
Expand All @@ -335,18 +359,19 @@ def generate_metrics(
"calculated for each component."
)
comptable["countnoise"] = dependence.compute_countnoise(
metric_maps["map Z"], metric_maps["map Z clusterized"]
stat_maps=metric_maps["map Z"],
stat_cl_maps=metric_maps["map Z clusterized"],
)

# Composite metrics
if "d_table_score" in required_metrics:
LGR.info("Calculating decision table score")
comptable["d_table_score"] = dependence.generate_decision_table_score(
comptable["kappa"],
comptable["dice_FT2"],
comptable["signal-noise_t"],
comptable["countnoise"],
comptable["countsigFT2"],
kappa=comptable["kappa"],
dice_ft2=comptable["dice_FT2"],
signal_minus_noise_t=comptable["signal-noise_t"],
countnoise=comptable["countnoise"],
countsig_ft2=comptable["countsigFT2"],
)

# External regressor-based metrics
Expand All @@ -368,6 +393,7 @@ def generate_metrics(
write_t2s0 = "map predicted T2" in metric_maps
if write_betas:
betas = metric_maps["map echo betas"]

if write_t2s0:
pred_t2_maps = metric_maps["map predicted T2"]
pred_s0_maps = metric_maps["map predicted S0"]
Expand All @@ -377,22 +403,22 @@ def generate_metrics(
echo_betas = betas[:, i_echo, :]
io_generator.save_file(
utils.unmask(echo_betas, mask),
"echo weight " + label + " map split img",
f"echo weight {label} map split img",
echo=(i_echo + 1),
)

if write_t2s0:
echo_pred_t2_maps = pred_t2_maps[:, i_echo, :]
io_generator.save_file(
utils.unmask(echo_pred_t2_maps, mask),
"echo T2 " + label + " split img",
f"echo T2 {label} split img",
echo=(i_echo + 1),
)

echo_pred_s0_maps = pred_s0_maps[:, i_echo, :]
io_generator.save_file(
utils.unmask(echo_pred_s0_maps, mask),
"echo S0 " + label + " split img",
f"echo S0 {label} split img",
echo=(i_echo + 1),
)

Expand Down
Loading

0 comments on commit 811b9fd

Please sign in to comment.