diff --git a/.github/workflows/R-CMD-check.yaml b/.github/workflows/R-CMD-check.yaml index 7fd85ab3..7fbd41f5 100644 --- a/.github/workflows/R-CMD-check.yaml +++ b/.github/workflows/R-CMD-check.yaml @@ -23,24 +23,25 @@ jobs: config: - {os: macOS-latest, r: 'release'} - {os: windows-latest, r: 'release'} - - {os: windows-latest, r: '3.6'} - - {os: ubuntu-16.04, r: 'devel', rspm: "https://packagemanager.rstudio.com/cran/__linux__/xenial/latest", http-user-agent: "R/4.0.0 (ubuntu-16.04) R (4.0.0 x86_64-pc-linux-gnu x86_64 linux-gnu) on GitHub Actions" } - - {os: ubuntu-16.04, r: 'release', rspm: "https://packagemanager.rstudio.com/cran/__linux__/xenial/latest"} + - {os: windows-latest, r: 'oldrel-1'} + - {os: ubuntu-18.04, r: 'devel', http-user-agent: 'release'} + - {os: ubuntu-18.04, r: 'release'} env: R_REMOTES_NO_ERRORS_FROM_WARNINGS: true - RSPM: ${{ matrix.config.rspm }} GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} + R_KEEP_PKG_SOURCE: yes steps: - uses: actions/checkout@v2 + + - uses: r-lib/actions/setup-pandoc@v1 - - uses: r-lib/actions/setup-r@master + - uses: r-lib/actions/setup-r@v1 with: r-version: ${{ matrix.config.r }} http-user-agent: ${{ matrix.config.http-user-agent }} - - - uses: r-lib/actions/setup-pandoc@master + use-public-rspm: true - name: Query dependencies run: | @@ -70,6 +71,7 @@ jobs: remotes::install_cran("vctrs") remotes::install_cran("parsnip") remotes::install_cran("tune") + remotes::install_cran("kknn") remotes::install_deps(dependencies = TRUE) remotes::install_cran("rcmdcheck") shell: Rscript {0} diff --git a/DESCRIPTION b/DESCRIPTION index ee4f468f..86e5a2a0 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -24,7 +24,7 @@ BugReports: https://github.com/tidymodels/stacks/issues Depends: R (>= 2.10) Imports: - tune (>= 0.1.2.9000), + tune (>= 0.1.3), dplyr (>= 1.0.0), rlang (>= 0.4.0), tibble (>= 2.1.3), @@ -32,8 +32,8 @@ Imports: parsnip (>= 0.0.4), workflows (>= 0.2.2), recipes (>= 0.1.15), - rsample (>= 0.0.9), - workflowsets (>= 0.0.0.9001), + rsample (>= 0.1.1), + workflowsets (>= 0.1.0), butcher (>= 0.1.3), yardstick, tidyr, @@ -58,5 +58,5 @@ Suggests: Encoding: UTF-8 LazyData: true Roxygen: list(markdown = TRUE) -RoxygenNote: 7.1.1.9001 +RoxygenNote: 7.1.2 VignetteBuilder: knitr diff --git a/R/collect_parameters.R b/R/collect_parameters.R index 2509a67e..1a35e263 100644 --- a/R/collect_parameters.R +++ b/R/collect_parameters.R @@ -66,7 +66,7 @@ collect_parameters.data_stack <- function(stack, candidates, ...) { attributes(stack)$model_metrics, candidates, attributes(stack)$model_defs, - stack + stack = stack ) } @@ -79,7 +79,7 @@ collect_parameters.model_stack <- function(stack, candidates, ...) { candidates, stack$model_defs, stack$coefs, - stack + stack = stack ) } diff --git a/data/class_folds.rda b/data/class_folds.rda index c1d7d285..fe370be9 100644 Binary files a/data/class_folds.rda and b/data/class_folds.rda differ diff --git a/data/class_res_nn.rda b/data/class_res_nn.rda index 193c0b12..f8471cff 100644 Binary files a/data/class_res_nn.rda and b/data/class_res_nn.rda differ diff --git a/data/class_res_rf.rda b/data/class_res_rf.rda index 960df91f..40277817 100644 Binary files a/data/class_res_rf.rda and b/data/class_res_rf.rda differ diff --git a/data/log_res_nn.rda b/data/log_res_nn.rda index f4900d4e..91b22b5a 100644 Binary files a/data/log_res_nn.rda and b/data/log_res_nn.rda differ diff --git a/data/log_res_rf.rda b/data/log_res_rf.rda index ba771388..9781aa5b 100644 Binary files a/data/log_res_rf.rda and b/data/log_res_rf.rda differ diff --git a/data/reg_folds.rda b/data/reg_folds.rda index 099b9d45..108b6a10 100644 Binary files a/data/reg_folds.rda and b/data/reg_folds.rda differ diff --git a/data/reg_res_lr.rda b/data/reg_res_lr.rda index de2b17ba..8e90350d 100644 Binary files a/data/reg_res_lr.rda and b/data/reg_res_lr.rda differ diff --git a/data/reg_res_sp.rda b/data/reg_res_sp.rda index 0a320d4d..d13386a0 100644 Binary files a/data/reg_res_sp.rda and b/data/reg_res_sp.rda differ diff --git a/data/reg_res_svm.rda b/data/reg_res_svm.rda index 97a29fb3..70ddcbee 100644 Binary files a/data/reg_res_svm.rda and b/data/reg_res_svm.rda differ diff --git a/data/tree_frogs_class_test.rda b/data/tree_frogs_class_test.rda index 7ad1ac66..ef97ce6f 100644 Binary files a/data/tree_frogs_class_test.rda and b/data/tree_frogs_class_test.rda differ diff --git a/data/tree_frogs_reg_test.rda b/data/tree_frogs_reg_test.rda index edd7fe36..4acc925c 100644 Binary files a/data/tree_frogs_reg_test.rda and b/data/tree_frogs_reg_test.rda differ diff --git a/man/example_data.Rd b/man/example_data.Rd index aafd3ca9..a65801a0 100644 --- a/man/example_data.Rd +++ b/man/example_data.Rd @@ -101,7 +101,7 @@ to the stimulus) using most all of the other variables as predictors. The source code for generating these objects is given below. -\if{html}{\out{
}}\preformatted{# setup: packages, data, resample, basic recipe ------------------------ +\if{html}{\out{
}}\preformatted{# setup: packages, data, resample, basic recipe ------------------------ library(stacks) library(tune) library(rsample) diff --git a/man/stacks_description.Rd b/man/stacks_description.Rd index 2a87a6fb..556521ca 100644 --- a/man/stacks_description.Rd +++ b/man/stacks_description.Rd @@ -9,11 +9,7 @@ \description{ \if{html}{\figure{logo.png}{options: align='right' alt='logo' width='120'}} -Model stacking is an ensemble technique - that involves training a model to combine the outputs of many - diverse statistical models, and has been shown to improve - predictive performance in a variety of settings. 'stacks' - implements a grammar for 'tidymodels'-aligned model stacking. +Model stacking is an ensemble technique that involves training a model to combine the outputs of many diverse statistical models, and has been shown to improve predictive performance in a variety of settings. 'stacks' implements a grammar for 'tidymodels'-aligned model stacking. } \seealso{ Useful links: diff --git a/tests/testthat/helper_data.Rda b/tests/testthat/helper_data.Rda index 95bc0176..be7a9c6d 100644 Binary files a/tests/testthat/helper_data.Rda and b/tests/testthat/helper_data.Rda differ diff --git a/tests/testthat/out/data_stack_class.txt b/tests/testthat/out/data_stack_class.txt index 5b969fda..23229636 100644 --- a/tests/testthat/out/data_stack_class.txt +++ b/tests/testthat/out/data_stack_class.txt @@ -1,5 +1,5 @@ > st_class_1 -# A data stack with 1 model definition and 10 candidate members: -# class_res_rf: 10 model configurations +# A data stack with 1 model definition and 9.66666666666667 candidate members: +# class_res_rf: 9.66666666666667 model configurations # Outcome: reflex (factor) diff --git a/tests/testthat/out/model_stack_class.txt b/tests/testthat/out/model_stack_class.txt index 57c6ba60..618f1b6b 100644 --- a/tests/testthat/out/model_stack_class.txt +++ b/tests/testthat/out/model_stack_class.txt @@ -2,28 +2,24 @@ Message: -- A stacked ensemble model ------------------------------------- Message: -Out of 20 possible candidate members, the ensemble retained 22. -Penalty: 1e-05. +Out of 19 possible candidate members, the ensemble retained 6. +Penalty: 0.1. Mixture: 1. -Message: Across the 3 classes, there are an average of 7.33 coefficients per class. +Message: Across the 3 classes, there are an average of 3 coefficients per class. Message: -The 10 highest weighted member classes are: +The 6 highest weighted member classes are: -# A tibble: 10 x 4 - member type weight class - - 1 .pred_mid_class_res_rf_1_04 rand_forest 39.8 low - 2 .pred_mid_class_res_rf_1_06 rand_forest 35.3 mid - 3 .pred_mid_class_res_rf_1_09 rand_forest 23.5 mid - 4 .pred_full_class_res_rf_1_05 rand_forest 21.7 full - 5 .pred_full_class_res_rf_1_04 rand_forest 17.0 low - 6 .pred_full_class_res_rf_1_09 rand_forest 16.6 mid - 7 .pred_mid_class_res_rf_1_02 rand_forest 13.3 low - 8 .pred_mid_class_res_rf_1_01 rand_forest 13.3 mid - 9 .pred_mid_class_res_rf_1_10 rand_forest 11.9 low -10 .pred_mid_class_res_rf_1_03 rand_forest 11.4 low +# A tibble: 6 x 4 + member type weight class + +1 .pred_full_class_res_rf_1_05 rand_forest 3.75 full +2 .pred_mid_class_res_rf_1_06 rand_forest 0.674 mid +3 .pred_full_class_res_rf_1_07 rand_forest 0.411 full +4 .pred_full_class_res_rf_1_01 rand_forest 0.0957 full +5 .pred_full_class_res_rf_1_06 rand_forest 0.0193 full +6 .pred_full_class_res_rf_1_04 rand_forest 0.0110 full Message: Members have not yet been fitted with `fit_members()`. diff --git a/tests/testthat/out/model_stack_class_fit.txt b/tests/testthat/out/model_stack_class_fit.txt index caeea28e..8bcf81c6 100644 --- a/tests/testthat/out/model_stack_class_fit.txt +++ b/tests/testthat/out/model_stack_class_fit.txt @@ -2,26 +2,22 @@ Message: -- A stacked ensemble model ------------------------------------- Message: -Out of 20 possible candidate members, the ensemble retained 22. -Penalty: 1e-05. +Out of 19 possible candidate members, the ensemble retained 6. +Penalty: 0.1. Mixture: 1. -Message: Across the 3 classes, there are an average of 7.33 coefficients per class. +Message: Across the 3 classes, there are an average of 3 coefficients per class. Message: -The 10 highest weighted member classes are: +The 6 highest weighted member classes are: -# A tibble: 10 x 4 - member type weight class - - 1 .pred_mid_class_res_rf_1_04 rand_forest 39.8 low - 2 .pred_mid_class_res_rf_1_06 rand_forest 35.3 mid - 3 .pred_mid_class_res_rf_1_09 rand_forest 23.5 mid - 4 .pred_full_class_res_rf_1_05 rand_forest 21.7 full - 5 .pred_full_class_res_rf_1_04 rand_forest 17.0 low - 6 .pred_full_class_res_rf_1_09 rand_forest 16.6 mid - 7 .pred_mid_class_res_rf_1_02 rand_forest 13.3 low - 8 .pred_mid_class_res_rf_1_01 rand_forest 13.3 mid - 9 .pred_mid_class_res_rf_1_10 rand_forest 11.9 low -10 .pred_mid_class_res_rf_1_03 rand_forest 11.4 low +# A tibble: 6 x 4 + member type weight class + +1 .pred_full_class_res_rf_1_05 rand_forest 3.75 full +2 .pred_mid_class_res_rf_1_06 rand_forest 0.674 mid +3 .pred_full_class_res_rf_1_07 rand_forest 0.411 full +4 .pred_full_class_res_rf_1_01 rand_forest 0.0957 full +5 .pred_full_class_res_rf_1_06 rand_forest 0.0193 full +6 .pred_full_class_res_rf_1_04 rand_forest 0.0110 full diff --git a/tests/testthat/out/model_stack_log.txt b/tests/testthat/out/model_stack_log.txt index bcee8f77..e8163c35 100644 --- a/tests/testthat/out/model_stack_log.txt +++ b/tests/testthat/out/model_stack_log.txt @@ -2,18 +2,19 @@ Message: -- A stacked ensemble model ------------------------------------- Message: -Out of 10 possible candidate members, the ensemble retained 2. -Penalty: 0.1. +Out of 10 possible candidate members, the ensemble retained 3. +Penalty: 1e-05. Mixture: 1. Message: -The 2 highest weighted member classes are: +The 3 highest weighted member classes are: -# A tibble: 2 x 3 +# A tibble: 3 x 3 member type weight -1 .pred_yes_log_res_rf_1_03 rand_forest 3.54 -2 .pred_yes_log_res_rf_1_06 rand_forest 0.0457 +1 .pred_yes_log_res_rf_1_05 rand_forest 3.56 +2 .pred_yes_log_res_rf_1_02 rand_forest 3.22 +3 .pred_yes_log_res_rf_1_09 rand_forest 0.226 Message: Members have not yet been fitted with `fit_members()`. diff --git a/tests/testthat/out/model_stack_log_fit.txt b/tests/testthat/out/model_stack_log_fit.txt index 543dee23..37d00985 100644 --- a/tests/testthat/out/model_stack_log_fit.txt +++ b/tests/testthat/out/model_stack_log_fit.txt @@ -2,16 +2,17 @@ Message: -- A stacked ensemble model ------------------------------------- Message: -Out of 10 possible candidate members, the ensemble retained 2. -Penalty: 0.1. +Out of 10 possible candidate members, the ensemble retained 3. +Penalty: 1e-05. Mixture: 1. Message: -The 2 highest weighted member classes are: +The 3 highest weighted member classes are: -# A tibble: 2 x 3 +# A tibble: 3 x 3 member type weight -1 .pred_yes_log_res_rf_1_03 rand_forest 3.54 -2 .pred_yes_log_res_rf_1_06 rand_forest 0.0457 +1 .pred_yes_log_res_rf_1_05 rand_forest 3.56 +2 .pred_yes_log_res_rf_1_02 rand_forest 3.22 +3 .pred_yes_log_res_rf_1_09 rand_forest 0.226 diff --git a/tests/testthat/out/model_stack_reg.txt b/tests/testthat/out/model_stack_reg.txt index b3c80d11..c533716f 100644 --- a/tests/testthat/out/model_stack_reg.txt +++ b/tests/testthat/out/model_stack_reg.txt @@ -2,19 +2,18 @@ Message: -- A stacked ensemble model ------------------------------------- Message: -Out of 5 possible candidate members, the ensemble retained 3. +Out of 5 possible candidate members, the ensemble retained 2. Penalty: 0.1. Mixture: 1. Message: -The 3 highest weighted members are: +The 2 highest weighted members are: -# A tibble: 3 x 3 +# A tibble: 2 x 3 member type weight -1 reg_res_svm_1_5 svm_rbf 2.88 -2 reg_res_svm_1_3 svm_rbf 0.895 -3 reg_res_svm_1_1 svm_rbf 0.410 +1 reg_res_svm_1_3 svm_rbf 1.26 +2 reg_res_svm_1_2 svm_rbf 0.136 Message: Members have not yet been fitted with `fit_members()`. diff --git a/tests/testthat/out/model_stack_reg_fit.txt b/tests/testthat/out/model_stack_reg_fit.txt index bb5ea420..63307382 100644 --- a/tests/testthat/out/model_stack_reg_fit.txt +++ b/tests/testthat/out/model_stack_reg_fit.txt @@ -2,17 +2,16 @@ Message: -- A stacked ensemble model ------------------------------------- Message: -Out of 5 possible candidate members, the ensemble retained 3. +Out of 5 possible candidate members, the ensemble retained 2. Penalty: 0.1. Mixture: 1. Message: -The 3 highest weighted members are: +The 2 highest weighted members are: -# A tibble: 3 x 3 +# A tibble: 2 x 3 member type weight -1 reg_res_svm_1_5 svm_rbf 2.88 -2 reg_res_svm_1_3 svm_rbf 0.895 -3 reg_res_svm_1_1 svm_rbf 0.410 +1 reg_res_svm_1_3 svm_rbf 1.26 +2 reg_res_svm_1_2 svm_rbf 0.136 diff --git a/tests/testthat/test_collect_parameters.R b/tests/testthat/test_collect_parameters.R index ba0395b8..a3b106bd 100644 --- a/tests/testthat/test_collect_parameters.R +++ b/tests/testthat/test_collect_parameters.R @@ -55,7 +55,7 @@ test_that("collect_parameters on a data stack works (regression)", { expect_equal(nrow(res), 5) expect_equal(ncol(res2), 2) - expect_equal(nrow(res2), 9) + expect_equal(nrow(res2), 10) expect_equal(ncol(res3), 1) expect_equal(nrow(res3), 1) @@ -73,7 +73,7 @@ test_that("collect_parameters on a model stack works (regression)", { expect_equal(nrow(res), 5) expect_equal(ncol(res2), 3) - expect_equal(nrow(res2), 9) + expect_equal(nrow(res2), 10) expect_true( all( @@ -100,5 +100,5 @@ test_that("collect_parameters works (classification)", { expect_equal(nrow(res), 10) expect_equal(ncol(res2), 6) - expect_equal(nrow(res2), 60) + expect_equal(nrow(res2), 57) })