Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update main figure 3 #55

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions 2.evaluate_model/README.md
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
# Evaluate NF1 Model
After training the NF1 model described in [1.train_models]("../1.train_models), and saving the results, we evaluate the performance of the NF1 model.
We evaluate the NF1 model on each split (train, validation, and test), the entire dataset, and each plate using the following metrics:
We evaluate the final NF1 model on each split (train, validation, and test), each plate, and across all plates using the following metrics:

- Precision
- Recall
- Accuracy
- F1 score
- Confusion matrices
- Threshold precision and recall scores

> **NOTE:** The precision and recall data cover the results from all the different parameter settings tested during the hyperparameter search. All other files contain the results from the final model (best hyperparameters).
jenna-tomkinson marked this conversation as resolved.
Show resolved Hide resolved

In addition to these changes, we save the feature importances of the model to gather insights about key morphology differences.
Binary file modified 3.figures/figures/main_figure_3_model_eval.png
jenna-tomkinson marked this conversation as resolved.
Show resolved Hide resolved
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
659 changes: 195 additions & 464 deletions 3.figures/main_figure_3/main_figure_3.ipynb

Large diffs are not rendered by default.

178 changes: 16 additions & 162 deletions 3.figures/main_figure_3/scripts/main_figure_3.r
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ output_main_figure_3 <- file.path(
figure_dir, "main_figure_3_model_eval.png"
)
results_dir <- file.path(
"../../2.evaluate_models/classify_genotypes/model_evaluation_data"
"../../2.evaluate_model/model_evaluation_data/"
)

# Load data
PR_results_file <- file.path(results_dir, "precision_recall_hyperparameters.parquet")
# Load data (includes optimization in this file)
PR_results_file <- file.path(results_dir, "precision_recall_final_model.parquet")

PR_results_df <- arrow::read_parquet(PR_results_file)

Expand All @@ -35,43 +35,6 @@ PR_results_df <- PR_results_df %>%
dim(PR_results_df)
head(PR_results_df)

# width <- 17
# height <- 12
# options(repr.plot.width = width, repr.plot.height = height)

# pr_by_plate_plot <- (
# ggplot(PR_results_df, aes(x = recall, y = precision, color = datasplit, linetype = shuffled_type))
# + geom_line(aes(linetype = shuffled_type), linewidth = 1)
# + facet_wrap(~plate)
# + theme_bw()
# + labs(color = "ML model\ndata split", linetype = "Features shuffled", x = "Recall", y = "Precision")
# # change the colors
# + scale_color_manual(values = c(
# "test" = brewer.pal(8, "Dark2")[6],
# "train" = brewer.pal(8, "Dark2")[3],
# "val" = brewer.pal(8, "Dark2")[2]
# ))
# + coord_fixed()
# # change the line thickness of the lines in the legend
# + guides(linetype = guide_legend(override.aes = list(size = 1)))
# # change the text size
# + theme(
# strip.text = element_text(size = 18),
# # x and y axis text size
# axis.text.x = element_text(size = 18),
# axis.text.y = element_text(size = 18),
# # x and y axis title size
# axis.title.x = element_text(size = 18),
# axis.title.y = element_text(size = 18),
# # legend text size
# legend.text = element_text(size = 18),
# legend.title = element_text(size = 18),
# )
# )

# pr_by_plate_plot


# Filter only rows with 'all_plates' in the 'plate' column
filtered_all_plates_pr_df <- PR_results_df[PR_results_df$plate == "all_plates", ]

Expand All @@ -90,6 +53,7 @@ pr_all_plates_plot <- (
"Train" = brewer.pal(8, "Dark2")[3],
"Val" = brewer.pal(8, "Dark2")[2]
))
+ scale_y_continuous(limits = c(0, 1))
# change the line thickness of the lines in the legend
+ guides(linetype = guide_legend(override.aes = list(size = 1)))
# change the text size
Expand All @@ -113,6 +77,10 @@ metrics_results_file <- file.path(results_dir, "metrics_final_model.parquet")

metrics_results_df <- arrow::read_parquet(metrics_results_file)

# Filter out rows where datasplit is "val" or "shuffled_val"
metrics_results_df <- metrics_results_df %>%
filter(!(datasplit %in% c("val", "shuffled_val")))

dim(metrics_results_df)
head(metrics_results_df)

Expand All @@ -124,57 +92,11 @@ metrics_results_df$datasplit <- sub("^shuffled_", "", metrics_results_df$dataspl

# Rename "data splits for interpretation
metrics_results_df <- metrics_results_df %>%
mutate(datasplit = recode(datasplit, "test" = "Test", "rest" = "Train"))
mutate(datasplit = recode(datasplit, "test" = "Test", "train" = "Train"))

dim(metrics_results_df)
head(metrics_results_df)

# set plot size
width <- 10
height <- 8
options(repr.plot.width = width, repr.plot.height = height)
# bar plot of the accuracy scores
accuracy_score_per_plate_plot <- (
ggplot(metrics_results_df, aes(x = shuffled_type, y = accuracy, fill = datasplit))
+ geom_bar(stat = "identity", position = "dodge")

# Add text labels on top of bars
+ geom_text(
aes(label = sprintf("%.2f", accuracy)),
position = position_dodge(width = 0.9),
vjust = -0.5,
size = 6
)

+ ylim(0, 1)
+ facet_wrap(~plate)
+ theme_bw()
+ ylab("Accuracy")
+ xlab("Features shuffled")
# change the legend title
+ labs(fill = "ML model\ndata split")
# change the colours
+ scale_fill_manual(values = c(
"Test" = brewer.pal(8, "Dark2")[6],
"Train" = brewer.pal(8, "Dark2")[3]
))
# change the text size
+ theme(
strip.text = element_text(size = 16),
# x and y axis text size
axis.text.x = element_text(size = 16),
axis.text.y = element_text(size = 16),
# x and y axis title size
axis.title.x = element_text(size = 16),
axis.title.y = element_text(size = 16),
# legend text size
legend.text = element_text(size = 16),
legend.title = element_text(size = 16),
)
)

accuracy_score_per_plate_plot

filtered_metrics_df <- metrics_results_df[metrics_results_df$plate == "all_plates", ]

width <- 10
Expand Down Expand Up @@ -225,6 +147,10 @@ CM_results_file <- file.path(results_dir, "confusion_matrix_final_model.parquet"

CM_results_df <- arrow::read_parquet(CM_results_file)

# Filter out rows where datasplit is "val" or "shuffled_val"
CM_results_df <- CM_results_df %>%
filter(!(datasplit %in% c("val", "shuffled_val")))

dim(CM_results_df)
head(CM_results_df)

Expand All @@ -236,7 +162,7 @@ CM_results_df$datasplit <- sub("^shuffled_", "", CM_results_df$datasplit)

# Rename "data splits for interpretation
CM_results_df <- CM_results_df %>%
mutate(datasplit = recode(datasplit, "test" = "Test", "rest" = "Train"))
mutate(datasplit = recode(datasplit, "test" = "Test", "train" = "Train"))

dim(CM_results_df)
head(CM_results_df)
Expand All @@ -251,74 +177,6 @@ CM_results_df <- CM_results_df %>%
dim(CM_results_df)
head(CM_results_df)

# Filter out rows with 'val' in the 'datasplit' column and 'shuffled' in the 'data_type' column
filtered_CM_df <- CM_results_df[!(CM_results_df$shuffled_type == "TRUE"), ]

# plot dimensions
width <- 14
height <- 11
options(repr.plot.width = width, repr.plot.height = height)
# plot a confusion matrix
confusion_matrix_per_plate_final_plot <- (
ggplot(filtered_CM_df, aes(x = factor(true_genotype, levels = rev(levels(factor(true_genotype)))), y = predicted_genotype)) +
facet_grid(plate ~ datasplit) +
geom_point(aes(color = ratio), size = 20, shape = 15) +
geom_text(aes(label = confusion_values)) +
scale_color_gradient("Ratio", low = "white", high = "red", limits = c(0, 1)) +
theme_bw() +
ylab("Predicted genotype") +
xlab("True genotype") +
# change the text size
theme(
strip.text = element_text(size = 16),
# x and y axis text size
axis.text.x = element_text(size = 16),
axis.text.y = element_text(size = 16),
# x and y axis title size
axis.title.x = element_text(size = 16),
axis.title.y = element_text(size = 16),
# legend text size
legend.text = element_text(size = 16),
legend.title = element_text(size = 16),
)
)

confusion_matrix_per_plate_final_plot

# Filter out rows with 'final' in the 'data_type' column
filtered_CM_df <- CM_results_df[!(CM_results_df$shuffled_type == "FALSE"), ]

# plot dimensions
width <- 14
height <- 11
options(repr.plot.width = width, repr.plot.height = height)
# plot a confusion matrix
confusion_matrix_per_plate_shuffled_plot <- (
ggplot(filtered_CM_df, aes(x = factor(true_genotype, levels = rev(levels(factor(true_genotype)))), y = predicted_genotype)) +
facet_grid(plate ~ datasplit) +
geom_point(aes(color = ratio), size = 28, shape = 15) +
geom_text(aes(label = confusion_values)) +
scale_color_gradient("Ratio", low = "white", high = "red", limits = c(0, 1)) +
theme_bw() +
ylab("Predicted genotype") +
xlab("True genotype") +
# change the text size
theme(
strip.text = element_text(size = 16),
# x and y axis text size
axis.text.x = element_text(size = 16),
axis.text.y = element_text(size = 16),
# x and y axis title size
axis.title.x = element_text(size = 16),
axis.title.y = element_text(size = 16),
# legend text size
legend.text = element_text(size = 16),
legend.title = element_text(size = 16),
)
)

confusion_matrix_per_plate_shuffled_plot

# Filter only rows with plate with "all_plates"
filtered_CM_df <- CM_results_df[(CM_results_df$plate == "all_plates"), ]

Expand Down Expand Up @@ -367,15 +225,11 @@ align_plot <- (
free(pr_all_plates_plot) |
confusion_matrix_all_plates_plot |
accuracy_score_all_plates_plot
) + plot_layout(widths = c(4,2,2))

align_plot
) + plot_layout(widths = c(3,2,2))

fig_3_gg <- (
align_plot
) + plot_annotation(tag_levels = "A") & theme(plot.tag = element_text(size = 25))

# Save or display the plot
# Save the plot
ggsave(output_main_figure_3, plot = fig_3_gg, dpi = 500, height = 6, width = 22)

fig_3_gg