add_candidates() takes a long time to collect predictions #90
Thanks for the thoughtful issue! This seems like a good idea to me, especially re: the benchmarking you've done above. @topepo, do you have any reservations here before I start working on a PR? Is there an easy way to note when
@topepo, pinging before I give this a go. Holler if you have any reservations here and/or tips for the conditional determining the value of the `summarize` argument.
Would it be as simple as examining the class of the resample data frame to see whether it is a bootstrap or some variation of cross-validation? See below.

```r
library(tidymodels)
#> Registered S3 method overwritten by 'tune':
#>   method                   from
#>   required_pkgs.model_spec parsnip
library(palmerpenguins)

# cross-validation - no repeats
vfold_cv(data = penguins) %>% class()
#> [1] "vfold_cv"   "rset"       "tbl_df"     "tbl"        "data.frame"

# cross-validation - with repeats
vfold_cv(data = penguins, repeats = 2) %>% class()
#> [1] "vfold_cv"   "rset"       "tbl_df"     "tbl"        "data.frame"

# monte-carlo cv
mc_cv(data = penguins) %>% class()
#> [1] "mc_cv"      "rset"       "tbl_df"     "tbl"        "data.frame"

# bootstrap
bootstraps(data = penguins) %>% class()
#> [1] "bootstraps" "rset"       "tbl_df"     "tbl"        "data.frame"
```

Created on 2021-11-20 by the reprex package (v2.0.1)

I'll admit I might not have fully thought it through, if different forms of cross-validation have more than one assessment set prediction for a given observation.
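The class check suggested above could be sketched as a small helper. This is a hypothetical sketch, not the stacks implementation — the function name `needs_summarize` is made up, and the mock objects below stand in for real rsample output (they only carry the class attribute shown in the reprex):

```r
# Hypothetical helper: branch on the rset subclass. bootstraps and mc_cv
# can assign one observation to more than one assessment set, so their
# predictions would need averaging; plain vfold_cv would not.
needs_summarize <- function(resamples) {
  inherits(resamples, c("bootstraps", "mc_cv"))
}

# Mock objects standing in for rsample output (class attribute only)
boot <- structure(list(), class = c("bootstraps", "rset", "tbl_df", "tbl", "data.frame"))
cv   <- structure(list(), class = c("vfold_cv", "rset", "tbl_df", "tbl", "data.frame"))

needs_summarize(boot)  # TRUE
needs_summarize(cv)    # FALSE
```

`inherits()` with a character vector returns `TRUE` if the object inherits from any of the named classes, so the check stays correct even though rsample objects carry several classes at once.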
@bensoltoff Ideally it would be this simple :) I just want to be absolutely sure that I don't falsely assert that it's okay not to summarize predictions in any situation in which I don't think

```r
library(stacks)
library(palmerpenguins)
library(tidymodels)
#> Registered S3 method overwritten by 'tune':
#>   method                   from
#>   required_pkgs.model_spec parsnip

same_nrow <- function(sampler) {
  folds <- sampler(data = penguins)

  tuned <-
    recipe(species ~ bill_length_mm + bill_depth_mm + flipper_length_mm,
           data = penguins) %>%
    step_impute_mean(all_numeric_predictors()) %>%
    workflow(multinom_reg(engine = "glmnet", penalty = tune(), mixture = 0)) %>%
    tune_grid(folds,
              grid = 10,
              control = control_stack_grid())

  nrow(collect_predictions(tuned, summarize = TRUE)) ==
    nrow(collect_predictions(tuned, summarize = FALSE))
}

lapply(
  c(vfold_cv, bootstraps, mc_cv),
  same_nrow
)
#> [[1]]
#> [1] TRUE
#>
#> [[2]]
#> [1] FALSE
#>
#> [[3]]
#> [1] FALSE
```

Created on 2021-12-29 by the reprex package (v2.0.0)
```r
library(stacks)
library(palmerpenguins)
library(tidymodels)
#> Registered S3 method overwritten by 'tune':
#>   method                   from
#>   required_pkgs.model_spec parsnip
library(tidyverse)

tree_frogs_reg <-
  tree_frogs %>%
  filter(!is.na(latency)) %>%
  select(-clutch, -hatched)

reg_folds <- rsample::vfold_cv(tree_frogs_reg, v = 5, repeats = 2)

tree_frogs_reg_rec <-
  recipes::recipe(latency ~ ., data = tree_frogs_reg) %>%
  recipes::step_dummy(recipes::all_nominal()) %>%
  recipes::step_zv(recipes::all_predictors())

svm_spec <-
  parsnip::svm_rbf(
    cost = tune::tune(),
    rbf_sigma = tune::tune()
  ) %>%
  parsnip::set_engine("kernlab") %>%
  parsnip::set_mode("regression")

reg_wf_svm <-
  workflows::workflow() %>%
  workflows::add_model(svm_spec) %>%
  workflows::add_recipe(tree_frogs_reg_rec)

reg_res_svm <-
  tune::tune_grid(
    object = reg_wf_svm,
    resamples = reg_folds,
    grid = 5,
    control = control_stack_grid()
  )

nrow(collect_predictions(reg_res_svm, summarize = TRUE)) ==
  nrow(collect_predictions(reg_res_svm, summarize = FALSE))
#> [1] FALSE
```

Created on 2021-12-29 by the reprex package (v2.0.0)

In all, I don't feel confident writing solid logic here. Very much open to input if any folks from tidymodels core know of straightforward rules here.
This issue has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue.
I was recently trying to use stacks to create a stacked model for a multinomial regression problem with 20,000 observations, 20 classes for the outcome of interest, and four sets of tuned candidate models trained via 5-fold cross-validation. Overall there should have been about 90 candidate models available. However, I could never get past the `add_candidates()` step. The computer was running for hours and hours and never even got to blending the predictions.

I believe the issue is caused by how `add_candidates()` relies on `tune::collect_predictions()` to extract the assessment set predictions for each candidate model.

stacks/R/add_candidates.R Lines 477 to 478 in 0985737

When `summarize = TRUE`, `collect_predictions()` uses the internal `prob_summarize()` function to generate average probabilities for each observation. Based on my understanding of `collect_predictions()`, this is necessary when there is more than one assessment set prediction for a single observation in a single candidate model. However, in the case of standard cross-validation, will there be more than one assessment set prediction per observation per candidate model?

Consider the example below. The output of `collect_predictions()` is essentially the same whether or not one summarizes the predictions. The only significant difference I see is that the `id` column indicating the fold is dropped, and the column order is adjusted. Otherwise they contain the same rows and the same predicted probabilities, but collection with summarizing is 18 times slower.

Created on 2021-09-13 by the reprex package (v2.0.1)

I understand it is necessary if one uses a bootstrap resampling technique, but is it necessary to use `summarize = TRUE` if the resampling is cross-validation? Can `add_candidates()` be written in a way to detect the form of resampling and only use `summarize = TRUE` if bootstrap resampling is used?
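For intuition, the kind of per-observation averaging that `summarize = TRUE` performs can be sketched in a few lines of base R. The data below are hypothetical, and this is not the internal `prob_summarize()` implementation — just an illustration of the work that becomes redundant under plain v-fold CV:

```r
# Hypothetical assessment-set predictions: observation 1 landed in two
# bootstrap assessment sets, so it has two predicted probabilities;
# observation 2 appeared only once.
preds <- data.frame(
  .row          = c(1, 1, 2),
  .pred_class_a = c(0.6, 0.8, 0.3)
)

# Averaging per observation collapses the duplicates: .row 1 -> 0.7,
# .row 2 -> 0.3. Under unrepeated v-fold CV every .row appears exactly
# once, so this pass changes nothing and only costs time.
aggregate(.pred_class_a ~ .row, data = preds, FUN = mean)
```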