Skip to content

Commit

Permalink
merge pr #82: fixes for workflow sets and elastic net models
Browse files Browse the repository at this point in the history
  • Loading branch information
simonpcouch authored May 7, 2021
2 parents 039cad0 + 145ca5c commit 35cb85c
Show file tree
Hide file tree
Showing 27 changed files with 294 additions and 78 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ Imports:
tibble (>= 2.1.3),
purrr (>= 0.3.2),
parsnip (>= 0.0.4),
workflows (>= 0.2.1.9000),
workflows (>= 0.2.2),
recipes (>= 0.1.15),
rsample (>= 0.0.9),
workflowsets (>= 0.0.0.9001),
Expand Down
9 changes: 9 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,15 @@ To be released as 0.2.1

* Various bug fixes and improvements to documentation.

### Bug fixes

* Updates for importing workflow sets that use the `add_variables()`
preprocessor.

* Plot fixes for cases where coefficients are negative.

* Performance and member plots now show the effect of multiple mixture values.

## v0.2.0

### Breaking changes
Expand Down
6 changes: 5 additions & 1 deletion R/add_candidates.R
Original file line number Diff line number Diff line change
Expand Up @@ -369,8 +369,12 @@ stack_workflow <- function(x) {

if (inherits(pre, "formula")) {
res <- res %>% workflows::add_formula(pre)
} else {
} else if (inherits(pre, "recipe")) {
res <- res %>% workflows::add_recipe(pre)
} else if (inherits(pre, "workflow_variables")) {
res <- res %>% workflows::add_variables(variables = pre)
} else {
rlang::abort(paste0("Can't add a preprocessor of class '", class(pre)[1], "'"))
}

res
Expand Down
26 changes: 18 additions & 8 deletions R/blend_predictions.R
Original file line number Diff line number Diff line change
Expand Up @@ -263,23 +263,33 @@ check_regularization <- function(x, arg) {
glmnet_metrics <- function(x) {
res <- tune::collect_metrics(x)
pens <- sort(unique(res$penalty))
x$glmnet_fits <- purrr::map(x$.extracts, ~ .x$.extracts[[1]])
num_mem <-
purrr::map_dfr(x$glmnet_fits, num_members, pens) %>%
dplyr::group_by(penalty) %>%
dplyr::select(x, id, .extracts) %>%
tidyr::unnest(.extracts) %>%
dplyr::group_nest(id, penalty, mixture) %>%
# There are redundant model objects over penalty values
dplyr::mutate(data = purrr::map(data, ~ .x$.extracts[[1]])) %>%
dplyr::mutate(
members = purrr::map(data, ~ num_members(.x, pens))
) %>%
dplyr::select(mixture, members) %>%
tidyr::unnest(cols = members) %>%
dplyr::group_by(penalty, mixture) %>%
dplyr::summarize(
.metric = "num_members",
.estimator = "Poisson",
mean = mean(members, na.rm = TRUE),
n = sum(!is.na(members)),
std_err = sqrt(mean/n)
std_err = sqrt(mean/n),
.groups = "drop"
) %>%
dplyr::ungroup() %>%
dplyr::full_join(
res %>% dplyr::select(penalty, .config) %>% dplyr::distinct(),
by = "penalty"
res %>% dplyr::select(penalty, mixture, .config) %>% dplyr::distinct(),
by = c("penalty", "mixture")
)
dplyr::bind_rows(res, num_mem)

dplyr::bind_rows(res, num_mem) %>%
dplyr::arrange(.config)
}

num_members <- function(x, penalties) {
Expand Down
32 changes: 25 additions & 7 deletions R/plots.R
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ member_plot <- function(x) {

plot_dat <-
dat %>%
dplyr::select(penalty, .config, mean, .metric)
dplyr::select(penalty, mixture, .config, mean, .metric)

memb_data <-
dplyr::filter(plot_dat, .metric == "num_members") %>%
Expand All @@ -57,10 +57,19 @@ member_plot <- function(x) {

other_metrics <- dplyr::filter(plot_dat, .metric != "num_members")

plot_dat <- dplyr::full_join(memb_data, other_metrics, by = c("penalty", ".config"))
plot_dat <- dplyr::full_join(memb_data, other_metrics, by = c("penalty", "mixture", ".config"))

mult_mix <- length(unique(plot_dat$mixture)) > 1

if (mult_mix) {
plot_dat$mixture <- format(plot_dat$mixture)
p <- ggplot2::ggplot(plot_dat, ggplot2::aes(x = num_members, y = mean, col = mixture))
} else {
p <- ggplot2::ggplot(plot_dat, ggplot2::aes(x = num_members, y = mean))
}

p <-
ggplot2::ggplot(plot_dat, ggplot2::aes(x = num_members, y = mean)) +
p +
ggplot2::geom_point() +
ggplot2::facet_wrap(
~.metric,
Expand All @@ -74,8 +83,17 @@ member_plot <- function(x) {

performance_plot <- function(x) {
dat <- x$metrics
mult_mix <- length(unique(dat$mixture)) > 1

if (mult_mix) {
dat$mixture <- format(dat$mixture)
p <- ggplot2::ggplot(dat, ggplot2::aes(x = penalty, y = mean, col = mixture))
} else {
p <- ggplot2::ggplot(dat, ggplot2::aes(x = penalty, y = mean))
}
p <-
ggplot2::ggplot(dat, ggplot2::aes(x = penalty, y = mean)) +
p +
ggplot2::geom_vline(xintercept = x$penalty$penalty, lty = 2) +
ggplot2::geom_point() +
ggplot2::geom_path() +
ggplot2::facet_wrap(~ .metric, scales = "free_y", ncol = 1) +
Expand All @@ -92,7 +110,7 @@ weights_plot <- function(x, penalty = x$penalty$penalty, n = Inf) {
dat_order <-
dat %>%
dplyr::group_by(model, terms) %>%
dplyr::summarize(mean = max(weight, na.rm = TRUE)) %>%
dplyr::summarize(mean = max(abs(weight), na.rm = TRUE)) %>%
dplyr::ungroup() %>%
dplyr::arrange(mean) %>%
dplyr::mutate(member = dplyr::row_number()) %>%
Expand All @@ -101,14 +119,14 @@ weights_plot <- function(x, penalty = x$penalty$penalty, n = Inf) {
} else {
dat <-
dat %>%
dplyr::arrange(weight) %>%
dplyr::arrange(abs(weight)) %>%
dplyr::mutate(member = dplyr::row_number())
}
p <-
ggplot2::ggplot(dat, ggplot2::aes(x = weight, y = format(member), fill = model)) +
ggplot2::geom_bar(stat = "identity") +
ggplot2::ylab("Member") +
ggplot2::ggtitle(paste("penalty =", format(x$coefs$spec$args$penalty, digits = 3))) +
ggplot2::ggtitle(paste("penalty =", format(x$coefs$spec$args$penalty, digits = 3, scientific = FALSE))) +
ggplot2::geom_vline(xintercept = 0) +
ggplot2::xlab("Stacking Coefficient")

Expand Down
4 changes: 2 additions & 2 deletions R/print.R
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ top_coefs <- function(x, penalty = x$penalty$penalty, n = 10) {
res <-
dplyr::left_join(betas, sub_models, by = "terms") %>%
dplyr::left_join(model_types, by = "model_name") %>%
dplyr::top_n(n, estimate) %>%
dplyr::arrange(dplyr::desc(estimate))
dplyr::top_n(n, abs(estimate)) %>%
dplyr::arrange(dplyr::desc(abs(estimate)))

if (any(names(res) == "class")) {
pred_levels <-
Expand Down
3 changes: 3 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,18 @@ utils::globalVariables(c(
"assess_object",
"coef",
"contains",
"data",
"estimate",
"estimate.x",
"estimate.y",
".extracts",
"id",
"idx",
"lp",
"mem",
"member",
"members",
"mixture",
"model",
"model_type",
"n",
Expand Down
Binary file modified data/class_folds.rda
Binary file not shown.
Binary file modified data/class_res_nn.rda
Binary file not shown.
Binary file modified data/class_res_rf.rda
Binary file not shown.
Binary file modified data/log_res_nn.rda
Binary file not shown.
Binary file modified data/log_res_rf.rda
Binary file not shown.
Binary file modified data/reg_folds.rda
Binary file not shown.
Binary file modified data/reg_res_lr.rda
Binary file not shown.
Binary file modified data/reg_res_sp.rda
Binary file not shown.
Binary file modified data/reg_res_svm.rda
Binary file not shown.
Binary file modified data/tree_frogs_class_test.rda
Binary file not shown.
Binary file modified data/tree_frogs_reg_test.rda
Binary file not shown.
10 changes: 9 additions & 1 deletion man-roxygen/example_models.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,14 @@ st_class_1 <-
stacks() %>%
add_candidates(class_res_rf)
st_class_1_non_neg <-
st_class_1 %>%
blend_predictions(non_negative = FALSE)
st_class_1_mixed <-
st_class_1 %>%
blend_predictions(non_negative = FALSE, mixture = c(.25, .75))
st_class_1_ <-
st_class_1 %>%
blend_predictions()
Expand Down Expand Up @@ -337,5 +345,5 @@ reg_res_svm_5 <-
)
# save the environment in an .Rda and load in unit tests
save.image(file = "tests/testthat/helper_data.Rda", version = 2)
save.image(file = "tests/testthat/helper_data.Rda", version = 2, compress = "xz")
```
Binary file modified tests/testthat/helper_data.Rda
Binary file not shown.
28 changes: 14 additions & 14 deletions tests/testthat/out/model_stack_class.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,28 @@
Message: -- A stacked ensemble model -------------------------------------

Message:
Out of 20 possible candidate members, the ensemble retained 13.
Penalty: 1e-04.
Mixture: .
Out of 20 possible candidate members, the ensemble retained 22.
Penalty: 1e-05.
Mixture: 1.

Message: Across the 3 classes, there are an average of 4.33 coefficients per class.
Message: Across the 3 classes, there are an average of 7.33 coefficients per class.

Message:
The 10 highest weighted member classes are:

# A tibble: 10 x 4
member type weight class
<chr> <chr> <dbl> <chr>
1 .pred_mid_class_res_rf_1_04 rand_forest 27.5 low
2 .pred_mid_class_res_rf_1_01 rand_forest 16.1 mid
3 .pred_mid_class_res_rf_1_06 rand_forest 15.2 mid
4 .pred_full_class_res_rf_1_05 rand_forest 11.9 full
5 .pred_mid_class_res_rf_1_07 rand_forest 11.6 low
6 .pred_mid_class_res_rf_1_10 rand_forest 11.0 mid
7 .pred_mid_class_res_rf_1_08 rand_forest 7.20 low
8 .pred_mid_class_res_rf_1_02 rand_forest 5.92 low
9 .pred_full_class_res_rf_1_04 rand_forest 4.40 low
10 .pred_full_class_res_rf_1_10 rand_forest 4.17 mid
1 .pred_mid_class_res_rf_1_04 rand_forest 39.8 low
2 .pred_mid_class_res_rf_1_06 rand_forest 35.3 mid
3 .pred_mid_class_res_rf_1_09 rand_forest 23.5 mid
4 .pred_full_class_res_rf_1_05 rand_forest 21.7 full
5 .pred_full_class_res_rf_1_04 rand_forest 17.0 low
6 .pred_full_class_res_rf_1_09 rand_forest 16.6 mid
7 .pred_mid_class_res_rf_1_02 rand_forest 13.3 low
8 .pred_mid_class_res_rf_1_01 rand_forest 13.3 mid
9 .pred_mid_class_res_rf_1_10 rand_forest 11.9 low
10 .pred_mid_class_res_rf_1_03 rand_forest 11.4 low

Message:
Members have not yet been fitted with `fit_members()`.
Expand Down
28 changes: 14 additions & 14 deletions tests/testthat/out/model_stack_class_fit.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,26 @@
Message: -- A stacked ensemble model -------------------------------------

Message:
Out of 20 possible candidate members, the ensemble retained 13.
Penalty: 1e-04.
Mixture: .
Out of 20 possible candidate members, the ensemble retained 22.
Penalty: 1e-05.
Mixture: 1.

Message: Across the 3 classes, there are an average of 4.33 coefficients per class.
Message: Across the 3 classes, there are an average of 7.33 coefficients per class.

Message:
The 10 highest weighted member classes are:

# A tibble: 10 x 4
member type weight class
<chr> <chr> <dbl> <chr>
1 .pred_mid_class_res_rf_1_04 rand_forest 27.5 low
2 .pred_mid_class_res_rf_1_01 rand_forest 16.1 mid
3 .pred_mid_class_res_rf_1_06 rand_forest 15.2 mid
4 .pred_full_class_res_rf_1_05 rand_forest 11.9 full
5 .pred_mid_class_res_rf_1_07 rand_forest 11.6 low
6 .pred_mid_class_res_rf_1_10 rand_forest 11.0 mid
7 .pred_mid_class_res_rf_1_08 rand_forest 7.20 low
8 .pred_mid_class_res_rf_1_02 rand_forest 5.92 low
9 .pred_full_class_res_rf_1_04 rand_forest 4.40 low
10 .pred_full_class_res_rf_1_10 rand_forest 4.17 mid
1 .pred_mid_class_res_rf_1_04 rand_forest 39.8 low
2 .pred_mid_class_res_rf_1_06 rand_forest 35.3 mid
3 .pred_mid_class_res_rf_1_09 rand_forest 23.5 mid
4 .pred_full_class_res_rf_1_05 rand_forest 21.7 full
5 .pred_full_class_res_rf_1_04 rand_forest 17.0 low
6 .pred_full_class_res_rf_1_09 rand_forest 16.6 mid
7 .pred_mid_class_res_rf_1_02 rand_forest 13.3 low
8 .pred_mid_class_res_rf_1_01 rand_forest 13.3 mid
9 .pred_mid_class_res_rf_1_10 rand_forest 11.9 low
10 .pred_mid_class_res_rf_1_03 rand_forest 11.4 low

16 changes: 7 additions & 9 deletions tests/testthat/out/model_stack_log.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,18 @@
Message: -- A stacked ensemble model -------------------------------------

Message:
Out of 10 possible candidate members, the ensemble retained 4.
Penalty: 1e-06.
Mixture: .
Out of 10 possible candidate members, the ensemble retained 2.
Penalty: 0.1.
Mixture: 1.

Message:
The 4 highest weighted member classes are:
The 2 highest weighted member classes are:

# A tibble: 4 x 3
# A tibble: 2 x 3
member type weight
<chr> <chr> <dbl>
1 .pred_yes_log_res_rf_1_09 rand_forest 4.64
2 .pred_yes_log_res_rf_1_03 rand_forest 1.25
3 .pred_yes_log_res_rf_1_06 rand_forest 0.687
4 .pred_yes_log_res_rf_1_05 rand_forest 0.467
1 .pred_yes_log_res_rf_1_03 rand_forest 3.54
2 .pred_yes_log_res_rf_1_06 rand_forest 0.0457

Message:
Members have not yet been fitted with `fit_members()`.
Expand Down
16 changes: 7 additions & 9 deletions tests/testthat/out/model_stack_log_fit.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,16 @@
Message: -- A stacked ensemble model -------------------------------------

Message:
Out of 10 possible candidate members, the ensemble retained 4.
Penalty: 1e-06.
Mixture: .
Out of 10 possible candidate members, the ensemble retained 2.
Penalty: 0.1.
Mixture: 1.

Message:
The 4 highest weighted member classes are:
The 2 highest weighted member classes are:

# A tibble: 4 x 3
# A tibble: 2 x 3
member type weight
<chr> <chr> <dbl>
1 .pred_yes_log_res_rf_1_09 rand_forest 4.64
2 .pred_yes_log_res_rf_1_03 rand_forest 1.25
3 .pred_yes_log_res_rf_1_06 rand_forest 0.687
4 .pred_yes_log_res_rf_1_05 rand_forest 0.467
1 .pred_yes_log_res_rf_1_03 rand_forest 3.54
2 .pred_yes_log_res_rf_1_06 rand_forest 0.0457

13 changes: 7 additions & 6 deletions tests/testthat/out/model_stack_reg.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,19 @@
Message: -- A stacked ensemble model -------------------------------------

Message:
Out of 5 possible candidate members, the ensemble retained 2.
Out of 5 possible candidate members, the ensemble retained 3.
Penalty: 0.1.
Mixture: .
Mixture: 1.

Message:
The 2 highest weighted members are:
The 3 highest weighted members are:

# A tibble: 2 x 3
# A tibble: 3 x 3
member type weight
<chr> <chr> <dbl>
1 reg_res_svm_1_3 svm_rbf 1.19
2 reg_res_svm_1_1 svm_rbf 0.159
1 reg_res_svm_1_5 svm_rbf 2.88
2 reg_res_svm_1_3 svm_rbf 0.895
3 reg_res_svm_1_1 svm_rbf 0.410

Message:
Members have not yet been fitted with `fit_members()`.
Expand Down
13 changes: 7 additions & 6 deletions tests/testthat/out/model_stack_reg_fit.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,17 @@
Message: -- A stacked ensemble model -------------------------------------

Message:
Out of 5 possible candidate members, the ensemble retained 2.
Out of 5 possible candidate members, the ensemble retained 3.
Penalty: 0.1.
Mixture: .
Mixture: 1.

Message:
The 2 highest weighted members are:
The 3 highest weighted members are:

# A tibble: 2 x 3
# A tibble: 3 x 3
member type weight
<chr> <chr> <dbl>
1 reg_res_svm_1_3 svm_rbf 1.19
2 reg_res_svm_1_1 svm_rbf 0.159
1 reg_res_svm_1_5 svm_rbf 2.88
2 reg_res_svm_1_3 svm_rbf 0.895
3 reg_res_svm_1_1 svm_rbf 0.410

Loading

0 comments on commit 35cb85c

Please sign in to comment.