add_candidates() takes a long time to collect predictions #90

Closed
bensoltoff opened this issue Sep 13, 2021 · 5 comments

@bensoltoff

I was recently trying to use stacks to create a stacked model for a multinomial regression problem with 20,000 observations, 20 classes for the outcome of interest, and four sets of tuned candidate models trained via 5-fold cross-validation. Overall, there should have been about 90 candidate models available. However, I could never get past the add_candidates() step: the computer ran for hours and hours and never even got to blending the predictions.

I believe the issue is caused by how add_candidates() relies on tune::collect_predictions() to extract the assessment set predictions for each candidate model.

stacks/R/add_candidates.R

Lines 477 to 478 in 0985737

res <- tune::collect_predictions(x, summarize = TRUE) %>%
  dplyr::rename_with(make.names, .cols = dplyr::starts_with(".pred"))

When summarize = TRUE, collect_predictions() uses the internal prob_summarize() function to generate average probabilities for each observation. Based on my understanding of collect_predictions(), this is necessary when there is more than one assessment set prediction for a single observation in a single candidate model. However, in the case of standard cross-validation, will there ever be more than one assessment set prediction per observation per candidate model?
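For intuition, here is a rough dplyr sketch of that averaging step, using the tuned object from the reprex below. This is illustrative only, not the internal prob_summarize() code, and it ignores how .pred_class and the outcome column are carried along:

library(dplyr)

# unsummarized assessment set predictions: one row per holdout prediction
unsummarized <- tune::collect_predictions(tuned, summarize = FALSE)

# average the class probabilities per observation within each candidate
unsummarized %>%
  group_by(.row, .config, penalty) %>%
  summarise(
    across(starts_with(".pred_") & where(is.numeric), mean),
    .groups = "drop"
  )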

Consider the example below. The output of collect_predictions() is essentially the same whether or not one summarizes the predictions. The only significant differences I see are that the id column indicating the fold is dropped and that the column order is adjusted. Otherwise they contain the same rows and the same predicted probabilities, but collection with summarize = TRUE is about 18 times slower.

library(tidymodels)
#> Registered S3 method overwritten by 'tune':
#>   method                   from   
#>   required_pkgs.model_spec parsnip
library(stacks)
library(palmerpenguins)

set.seed(123)

# create folds
penguins_folds <- vfold_cv(data = penguins)

# create basic model and tune over penalty
tuned <- recipe(species ~ bill_length_mm + bill_depth_mm +
  flipper_length_mm, data = penguins) %>%
  step_impute_mean(all_numeric_predictors()) %>%
  workflow(multinom_reg(engine = "glmnet", penalty = tune(), mixture = 0)) %>%
  tune_grid(penguins_folds,
    grid = 10,
    control = control_stack_grid()
  )

# collect_predictions
tuned_bm <- microbenchmark::microbenchmark(
  summarize_true = collect_predictions(tuned, summarize = TRUE),
  summarize_false = collect_predictions(tuned, summarize = FALSE)
)

# collection with summarize takes much longer
tuned_bm
#> Unit: milliseconds
#>             expr       min        lq      mean    median        uq       max
#>   summarize_true 336.65333 345.39742 354.68878 350.88546 358.25427 471.81579
#>  summarize_false  17.45312  17.68013  18.04854  17.83158  18.01736  22.87325
#>  neval cld
#>    100   b
#>    100  a
autoplot(tuned_bm)
#> Coordinate system already present. Adding new coordinate system, which will replace the existing one.

# but for non-bootstrapped resampling, is this necessary?
collect_predictions(tuned, summarize = FALSE)
#> # A tibble: 3,440 × 9
#>    id     .pred_Adelie .pred_Chinstrap .pred_Gentoo  .row  penalty .pred_class
#>    <chr>         <dbl>           <dbl>        <dbl> <int>    <dbl> <fct>      
#>  1 Fold01        0.924          0.0657      0.0103     12 3.49e-10 Adelie     
#>  2 Fold01        0.946          0.0514      0.00310    22 3.49e-10 Adelie     
#>  3 Fold01        0.927          0.0631      0.00972    24 3.49e-10 Adelie     
#>  4 Fold01        0.828          0.133       0.0391     42 3.49e-10 Adelie     
#>  5 Fold01        0.827          0.166       0.00685    47 3.49e-10 Adelie     
#>  6 Fold01        0.940          0.0296      0.0303     69 3.49e-10 Adelie     
#>  7 Fold01        0.965          0.0223      0.0129     75 3.49e-10 Adelie     
#>  8 Fold01        0.926          0.0374      0.0368     79 3.49e-10 Adelie     
#>  9 Fold01        0.920          0.0688      0.0108     90 3.49e-10 Adelie     
#> 10 Fold01        0.856          0.108       0.0356     98 3.49e-10 Adelie     
#> # … with 3,430 more rows, and 2 more variables: species <fct>, .config <chr>
collect_predictions(tuned, summarize = TRUE)
#> # A tibble: 3,440 × 8
#>     .row  penalty species .config      .pred_Adelie .pred_Chinstrap .pred_Gentoo
#>    <int>    <dbl> <fct>   <chr>               <dbl>           <dbl>        <dbl>
#>  1     1 3.49e-10 Adelie  Preprocesso…        0.910          0.0851      0.00504
#>  2     1 1.24e- 9 Adelie  Preprocesso…        0.910          0.0851      0.00504
#>  3     1 5.03e- 8 Adelie  Preprocesso…        0.910          0.0851      0.00504
#>  4     1 8.57e- 7 Adelie  Preprocesso…        0.910          0.0851      0.00504
#>  5     1 1.00e- 6 Adelie  Preprocesso…        0.910          0.0851      0.00504
#>  6     1 9.30e- 5 Adelie  Preprocesso…        0.910          0.0851      0.00504
#>  7     1 4.31e- 4 Adelie  Preprocesso…        0.910          0.0851      0.00504
#>  8     1 5.25e- 3 Adelie  Preprocesso…        0.910          0.0851      0.00504
#>  9     1 4.90e- 2 Adelie  Preprocesso…        0.903          0.0903      0.00630
#> 10     1 8.81e- 1 Adelie  Preprocesso…        0.680          0.181       0.138  
#> # … with 3,430 more rows, and 1 more variable: .pred_class <fct>

Created on 2021-09-13 by the reprex package (v2.0.1)

I understand it is necessary if one uses a bootstrap resampling technique, but is it necessary to use summarize = TRUE if the resampling is cross-validation? Could add_candidates() be written in a way that detects the form of resampling and only uses summarize = TRUE when bootstrap resampling is used?

@simonpcouch
Collaborator

Thanks for the thoughtful issue! This seems like a good idea to me, especially re: the benchmarking you've done above. @topepo, do you have any reservations here before I start working on a PR? Is there an easy way to note when summarize is not needed?

@simonpcouch
Collaborator

@topepo, pinging before I give this a go. Holler if you have any reservations here and/or tips for the conditional determining the value of the summarize argument!

@bensoltoff
Author

Would it be as simple as examining the class of the resample data frame to see if it is type bootstrap or some variation of cross-validation? See below.

library(tidymodels)
#> Registered S3 method overwritten by 'tune':
#>   method                   from   
#>   required_pkgs.model_spec parsnip
library(palmerpenguins)

# cross-validation - no repeats
vfold_cv(data = penguins) %>% class()
#> [1] "vfold_cv"   "rset"       "tbl_df"     "tbl"        "data.frame"

# cross-validation - with repeats
vfold_cv(data = penguins, repeats = 2) %>% class()
#> [1] "vfold_cv"   "rset"       "tbl_df"     "tbl"        "data.frame"

# monte-carlo cv
mc_cv(data = penguins) %>% class()
#> [1] "mc_cv"      "rset"       "tbl_df"     "tbl"        "data.frame"

# bootstrap
bootstraps(data = penguins) %>% class()
#> [1] "bootstraps" "rset"       "tbl_df"     "tbl"        "data.frame"

Created on 2021-11-20 by the reprex package (v2.0.1)

I'll admit I might not have fully thought this through for cases where other forms of cross-validation produce more than one assessment set prediction for a given observation.
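For illustration, a minimal sketch of that class-based conditional, using the penguins_folds and tuned objects from the first reprex. This is hypothetical, not stacks code, and (as the next comment shows) it would still miss cases like repeated v-fold CV:

# hypothetical helper, not part of stacks: summarize only when the resampling
# scheme can place the same original row in more than one assessment set
needs_summarizing <- function(rset) {
  inherits(rset, "bootstraps") || inherits(rset, "mc_cv")
}

collect_predictions(tuned, summarize = needs_summarizing(penguins_folds))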

@simonpcouch
Collaborator

@bensoltoff Ideally it would be this simple :) I just want to be absolutely sure that I don't falsely assert that it's okay not to summarize predictions in a situation where stacks actually ought to summarize them.

I don't think bootstraps or mc_cv are fair game here since each row in the original data has multiple holdout predictions--a quick reprex using your tune_grid example:

library(stacks)
library(palmerpenguins)
library(tidymodels)
#> Registered S3 method overwritten by 'tune':
#>   method                   from   
#>   required_pkgs.model_spec parsnip

same_nrow <- function(sampler) {
  folds <- sampler(data = penguins)
  
  tuned <- recipe(species ~ bill_length_mm + bill_depth_mm +
                    flipper_length_mm, data = penguins) %>%
    step_impute_mean(all_numeric_predictors()) %>%
    workflow(multinom_reg(engine = "glmnet", penalty = tune(), mixture = 0)) %>%
    tune_grid(folds,
              grid = 10,
              control = control_stack_grid()
    )
  
  nrow(collect_predictions(tuned, summarize = TRUE)) == 
    nrow(collect_predictions(tuned, summarize = FALSE))
}

lapply(
  c(vfold_cv, bootstraps, mc_cv), 
  same_nrow
)
#> [[1]]
#> [1] TRUE
#> 
#> [[2]]
#> [1] FALSE
#> 
#> [[3]]
#> [1] FALSE

Created on 2021-12-29 by the reprex package (v2.0.0)

vfold_cv is also not always fair game--a quick example working from the code used to generate the example model objects in the package:

library(stacks)
library(palmerpenguins)
library(tidymodels)
#> Registered S3 method overwritten by 'tune':
#>   method                   from   
#>   required_pkgs.model_spec parsnip
library(tidyverse)

tree_frogs_reg <- 
  tree_frogs %>% 
  filter(!is.na(latency)) %>%
  select(-clutch, -hatched)
  
reg_folds <- rsample::vfold_cv(tree_frogs_reg, v = 5, repeats = 2)
  
tree_frogs_reg_rec <- 
  recipes::recipe(latency ~ ., data = tree_frogs_reg) %>%
  recipes::step_dummy(recipes::all_nominal()) %>%
  recipes::step_zv(recipes::all_predictors())

svm_spec <- 
  parsnip::svm_rbf(
    cost = tune::tune(), 
    rbf_sigma = tune::tune()
  ) %>%
  parsnip::set_engine("kernlab") %>%
  parsnip::set_mode("regression")

reg_wf_svm <- 
  workflows::workflow() %>%
  workflows::add_model(svm_spec) %>%
  workflows::add_recipe(tree_frogs_reg_rec)

reg_res_svm <- 
  tune::tune_grid(
    object = reg_wf_svm,
    resamples = reg_folds, 
    grid = 5,
    control = control_stack_grid()
  )

nrow(collect_predictions(reg_res_svm, summarize = TRUE)) == 
  nrow(collect_predictions(reg_res_svm, summarize = FALSE))
#> [1] FALSE

Created on 2021-12-29 by the reprex package (v2.0.0)

All in all, I don't feel confident writing solid logic here. Very much open to input if any folks from tidymodels core know of straightforward rules.
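For reference, the nrow() comparisons in the reprexes above boil down to checking whether any original row appears more than once in the assessment sets of a single candidate configuration. A hedged, hypothetical sketch of that check (not stacks code):

# hypothetical helper: TRUE when some row has more than one assessment set
# prediction for a given candidate configuration, i.e. when summarizing is needed
has_repeated_assessment_rows <- function(tune_res) {
  preds <- tune::collect_predictions(tune_res, summarize = FALSE)
  any(duplicated(preds[c(".row", ".config")]))
}

# e.g. TRUE for the repeated v-fold result above, since each row is assessed twice
has_repeated_assessment_rows(reg_res_svm)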

@github-actions

This issue has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue.

@github-actions github-actions bot locked and limited conversation to collaborators Jan 13, 2022