diff --git a/.github/workflows/R-CMD-check.yaml b/.github/workflows/R-CMD-check.yaml
index 7fd85ab3..7fbd41f5 100644
--- a/.github/workflows/R-CMD-check.yaml
+++ b/.github/workflows/R-CMD-check.yaml
@@ -23,24 +23,25 @@ jobs:
config:
- {os: macOS-latest, r: 'release'}
- {os: windows-latest, r: 'release'}
- - {os: windows-latest, r: '3.6'}
- - {os: ubuntu-16.04, r: 'devel', rspm: "https://packagemanager.rstudio.com/cran/__linux__/xenial/latest", http-user-agent: "R/4.0.0 (ubuntu-16.04) R (4.0.0 x86_64-pc-linux-gnu x86_64 linux-gnu) on GitHub Actions" }
- - {os: ubuntu-16.04, r: 'release', rspm: "https://packagemanager.rstudio.com/cran/__linux__/xenial/latest"}
+ - {os: windows-latest, r: 'oldrel-1'}
+ - {os: ubuntu-18.04, r: 'devel', http-user-agent: 'release'}
+ - {os: ubuntu-18.04, r: 'release'}
env:
R_REMOTES_NO_ERRORS_FROM_WARNINGS: true
- RSPM: ${{ matrix.config.rspm }}
GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }}
+ R_KEEP_PKG_SOURCE: yes
steps:
- uses: actions/checkout@v2
+
+ - uses: r-lib/actions/setup-pandoc@v1
- - uses: r-lib/actions/setup-r@master
+ - uses: r-lib/actions/setup-r@v1
with:
r-version: ${{ matrix.config.r }}
http-user-agent: ${{ matrix.config.http-user-agent }}
-
- - uses: r-lib/actions/setup-pandoc@master
+ use-public-rspm: true
- name: Query dependencies
run: |
@@ -70,6 +71,7 @@ jobs:
remotes::install_cran("vctrs")
remotes::install_cran("parsnip")
remotes::install_cran("tune")
+ remotes::install_cran("kknn")
remotes::install_deps(dependencies = TRUE)
remotes::install_cran("rcmdcheck")
shell: Rscript {0}
diff --git a/DESCRIPTION b/DESCRIPTION
index ee4f468f..86e5a2a0 100644
--- a/DESCRIPTION
+++ b/DESCRIPTION
@@ -24,7 +24,7 @@ BugReports: https://github.com/tidymodels/stacks/issues
Depends:
R (>= 2.10)
Imports:
- tune (>= 0.1.2.9000),
+ tune (>= 0.1.3),
dplyr (>= 1.0.0),
rlang (>= 0.4.0),
tibble (>= 2.1.3),
@@ -32,8 +32,8 @@ Imports:
parsnip (>= 0.0.4),
workflows (>= 0.2.2),
recipes (>= 0.1.15),
- rsample (>= 0.0.9),
- workflowsets (>= 0.0.0.9001),
+ rsample (>= 0.1.1),
+ workflowsets (>= 0.1.0),
butcher (>= 0.1.3),
yardstick,
tidyr,
@@ -58,5 +58,5 @@ Suggests:
Encoding: UTF-8
LazyData: true
Roxygen: list(markdown = TRUE)
-RoxygenNote: 7.1.1.9001
+RoxygenNote: 7.1.2
VignetteBuilder: knitr
diff --git a/R/collect_parameters.R b/R/collect_parameters.R
index 2509a67e..1a35e263 100644
--- a/R/collect_parameters.R
+++ b/R/collect_parameters.R
@@ -66,7 +66,7 @@ collect_parameters.data_stack <- function(stack, candidates, ...) {
attributes(stack)$model_metrics,
candidates,
attributes(stack)$model_defs,
- stack
+ stack = stack
)
}
@@ -79,7 +79,7 @@ collect_parameters.model_stack <- function(stack, candidates, ...) {
candidates,
stack$model_defs,
stack$coefs,
- stack
+ stack = stack
)
}
diff --git a/data/class_folds.rda b/data/class_folds.rda
index c1d7d285..fe370be9 100644
Binary files a/data/class_folds.rda and b/data/class_folds.rda differ
diff --git a/data/class_res_nn.rda b/data/class_res_nn.rda
index 193c0b12..f8471cff 100644
Binary files a/data/class_res_nn.rda and b/data/class_res_nn.rda differ
diff --git a/data/class_res_rf.rda b/data/class_res_rf.rda
index 960df91f..40277817 100644
Binary files a/data/class_res_rf.rda and b/data/class_res_rf.rda differ
diff --git a/data/log_res_nn.rda b/data/log_res_nn.rda
index f4900d4e..91b22b5a 100644
Binary files a/data/log_res_nn.rda and b/data/log_res_nn.rda differ
diff --git a/data/log_res_rf.rda b/data/log_res_rf.rda
index ba771388..9781aa5b 100644
Binary files a/data/log_res_rf.rda and b/data/log_res_rf.rda differ
diff --git a/data/reg_folds.rda b/data/reg_folds.rda
index 099b9d45..108b6a10 100644
Binary files a/data/reg_folds.rda and b/data/reg_folds.rda differ
diff --git a/data/reg_res_lr.rda b/data/reg_res_lr.rda
index de2b17ba..8e90350d 100644
Binary files a/data/reg_res_lr.rda and b/data/reg_res_lr.rda differ
diff --git a/data/reg_res_sp.rda b/data/reg_res_sp.rda
index 0a320d4d..d13386a0 100644
Binary files a/data/reg_res_sp.rda and b/data/reg_res_sp.rda differ
diff --git a/data/reg_res_svm.rda b/data/reg_res_svm.rda
index 97a29fb3..70ddcbee 100644
Binary files a/data/reg_res_svm.rda and b/data/reg_res_svm.rda differ
diff --git a/data/tree_frogs_class_test.rda b/data/tree_frogs_class_test.rda
index 7ad1ac66..ef97ce6f 100644
Binary files a/data/tree_frogs_class_test.rda and b/data/tree_frogs_class_test.rda differ
diff --git a/data/tree_frogs_reg_test.rda b/data/tree_frogs_reg_test.rda
index edd7fe36..4acc925c 100644
Binary files a/data/tree_frogs_reg_test.rda and b/data/tree_frogs_reg_test.rda differ
diff --git a/man/example_data.Rd b/man/example_data.Rd
index aafd3ca9..a65801a0 100644
--- a/man/example_data.Rd
+++ b/man/example_data.Rd
@@ -101,7 +101,7 @@ to the stimulus) using most all of the other variables as predictors.
The source code for generating these objects is given below.
-\if{html}{\out{
}}\preformatted{# setup: packages, data, resample, basic recipe ------------------------
+\if{html}{\out{
}}\preformatted{# setup: packages, data, resample, basic recipe ------------------------
library(stacks)
library(tune)
library(rsample)
diff --git a/man/stacks_description.Rd b/man/stacks_description.Rd
index 2a87a6fb..556521ca 100644
--- a/man/stacks_description.Rd
+++ b/man/stacks_description.Rd
@@ -9,11 +9,7 @@
\description{
\if{html}{\figure{logo.png}{options: align='right' alt='logo' width='120'}}
-Model stacking is an ensemble technique
- that involves training a model to combine the outputs of many
- diverse statistical models, and has been shown to improve
- predictive performance in a variety of settings. 'stacks'
- implements a grammar for 'tidymodels'-aligned model stacking.
+Model stacking is an ensemble technique that involves training a model to combine the outputs of many diverse statistical models, and has been shown to improve predictive performance in a variety of settings. 'stacks' implements a grammar for 'tidymodels'-aligned model stacking.
}
\seealso{
Useful links:
diff --git a/tests/testthat/helper_data.Rda b/tests/testthat/helper_data.Rda
index 95bc0176..be7a9c6d 100644
Binary files a/tests/testthat/helper_data.Rda and b/tests/testthat/helper_data.Rda differ
diff --git a/tests/testthat/out/data_stack_class.txt b/tests/testthat/out/data_stack_class.txt
index 5b969fda..23229636 100644
--- a/tests/testthat/out/data_stack_class.txt
+++ b/tests/testthat/out/data_stack_class.txt
@@ -1,5 +1,5 @@
> st_class_1
-# A data stack with 1 model definition and 10 candidate members:
-# class_res_rf: 10 model configurations
+# A data stack with 1 model definition and 9.66666666666667 candidate members:
+# class_res_rf: 9.66666666666667 model configurations
# Outcome: reflex (factor)
diff --git a/tests/testthat/out/model_stack_class.txt b/tests/testthat/out/model_stack_class.txt
index 57c6ba60..618f1b6b 100644
--- a/tests/testthat/out/model_stack_class.txt
+++ b/tests/testthat/out/model_stack_class.txt
@@ -2,28 +2,24 @@
Message: -- A stacked ensemble model -------------------------------------
Message:
-Out of 20 possible candidate members, the ensemble retained 22.
-Penalty: 1e-05.
+Out of 19 possible candidate members, the ensemble retained 6.
+Penalty: 0.1.
Mixture: 1.
-Message: Across the 3 classes, there are an average of 7.33 coefficients per class.
+Message: Across the 3 classes, there are an average of 3 coefficients per class.
Message:
-The 10 highest weighted member classes are:
+The 6 highest weighted member classes are:
-# A tibble: 10 x 4
- member type weight class
-
- 1 .pred_mid_class_res_rf_1_04 rand_forest 39.8 low
- 2 .pred_mid_class_res_rf_1_06 rand_forest 35.3 mid
- 3 .pred_mid_class_res_rf_1_09 rand_forest 23.5 mid
- 4 .pred_full_class_res_rf_1_05 rand_forest 21.7 full
- 5 .pred_full_class_res_rf_1_04 rand_forest 17.0 low
- 6 .pred_full_class_res_rf_1_09 rand_forest 16.6 mid
- 7 .pred_mid_class_res_rf_1_02 rand_forest 13.3 low
- 8 .pred_mid_class_res_rf_1_01 rand_forest 13.3 mid
- 9 .pred_mid_class_res_rf_1_10 rand_forest 11.9 low
-10 .pred_mid_class_res_rf_1_03 rand_forest 11.4 low
+# A tibble: 6 x 4
+ member type weight class
+
+1 .pred_full_class_res_rf_1_05 rand_forest 3.75 full
+2 .pred_mid_class_res_rf_1_06 rand_forest 0.674 mid
+3 .pred_full_class_res_rf_1_07 rand_forest 0.411 full
+4 .pred_full_class_res_rf_1_01 rand_forest 0.0957 full
+5 .pred_full_class_res_rf_1_06 rand_forest 0.0193 full
+6 .pred_full_class_res_rf_1_04 rand_forest 0.0110 full
Message:
Members have not yet been fitted with `fit_members()`.
diff --git a/tests/testthat/out/model_stack_class_fit.txt b/tests/testthat/out/model_stack_class_fit.txt
index caeea28e..8bcf81c6 100644
--- a/tests/testthat/out/model_stack_class_fit.txt
+++ b/tests/testthat/out/model_stack_class_fit.txt
@@ -2,26 +2,22 @@
Message: -- A stacked ensemble model -------------------------------------
Message:
-Out of 20 possible candidate members, the ensemble retained 22.
-Penalty: 1e-05.
+Out of 19 possible candidate members, the ensemble retained 6.
+Penalty: 0.1.
Mixture: 1.
-Message: Across the 3 classes, there are an average of 7.33 coefficients per class.
+Message: Across the 3 classes, there are an average of 3 coefficients per class.
Message:
-The 10 highest weighted member classes are:
+The 6 highest weighted member classes are:
-# A tibble: 10 x 4
- member type weight class
-
- 1 .pred_mid_class_res_rf_1_04 rand_forest 39.8 low
- 2 .pred_mid_class_res_rf_1_06 rand_forest 35.3 mid
- 3 .pred_mid_class_res_rf_1_09 rand_forest 23.5 mid
- 4 .pred_full_class_res_rf_1_05 rand_forest 21.7 full
- 5 .pred_full_class_res_rf_1_04 rand_forest 17.0 low
- 6 .pred_full_class_res_rf_1_09 rand_forest 16.6 mid
- 7 .pred_mid_class_res_rf_1_02 rand_forest 13.3 low
- 8 .pred_mid_class_res_rf_1_01 rand_forest 13.3 mid
- 9 .pred_mid_class_res_rf_1_10 rand_forest 11.9 low
-10 .pred_mid_class_res_rf_1_03 rand_forest 11.4 low
+# A tibble: 6 x 4
+ member type weight class
+
+1 .pred_full_class_res_rf_1_05 rand_forest 3.75 full
+2 .pred_mid_class_res_rf_1_06 rand_forest 0.674 mid
+3 .pred_full_class_res_rf_1_07 rand_forest 0.411 full
+4 .pred_full_class_res_rf_1_01 rand_forest 0.0957 full
+5 .pred_full_class_res_rf_1_06 rand_forest 0.0193 full
+6 .pred_full_class_res_rf_1_04 rand_forest 0.0110 full
diff --git a/tests/testthat/out/model_stack_log.txt b/tests/testthat/out/model_stack_log.txt
index bcee8f77..e8163c35 100644
--- a/tests/testthat/out/model_stack_log.txt
+++ b/tests/testthat/out/model_stack_log.txt
@@ -2,18 +2,19 @@
Message: -- A stacked ensemble model -------------------------------------
Message:
-Out of 10 possible candidate members, the ensemble retained 2.
-Penalty: 0.1.
+Out of 10 possible candidate members, the ensemble retained 3.
+Penalty: 1e-05.
Mixture: 1.
Message:
-The 2 highest weighted member classes are:
+The 3 highest weighted member classes are:
-# A tibble: 2 x 3
+# A tibble: 3 x 3
member type weight
-1 .pred_yes_log_res_rf_1_03 rand_forest 3.54
-2 .pred_yes_log_res_rf_1_06 rand_forest 0.0457
+1 .pred_yes_log_res_rf_1_05 rand_forest 3.56
+2 .pred_yes_log_res_rf_1_02 rand_forest 3.22
+3 .pred_yes_log_res_rf_1_09 rand_forest 0.226
Message:
Members have not yet been fitted with `fit_members()`.
diff --git a/tests/testthat/out/model_stack_log_fit.txt b/tests/testthat/out/model_stack_log_fit.txt
index 543dee23..37d00985 100644
--- a/tests/testthat/out/model_stack_log_fit.txt
+++ b/tests/testthat/out/model_stack_log_fit.txt
@@ -2,16 +2,17 @@
Message: -- A stacked ensemble model -------------------------------------
Message:
-Out of 10 possible candidate members, the ensemble retained 2.
-Penalty: 0.1.
+Out of 10 possible candidate members, the ensemble retained 3.
+Penalty: 1e-05.
Mixture: 1.
Message:
-The 2 highest weighted member classes are:
+The 3 highest weighted member classes are:
-# A tibble: 2 x 3
+# A tibble: 3 x 3
member type weight
-1 .pred_yes_log_res_rf_1_03 rand_forest 3.54
-2 .pred_yes_log_res_rf_1_06 rand_forest 0.0457
+1 .pred_yes_log_res_rf_1_05 rand_forest 3.56
+2 .pred_yes_log_res_rf_1_02 rand_forest 3.22
+3 .pred_yes_log_res_rf_1_09 rand_forest 0.226
diff --git a/tests/testthat/out/model_stack_reg.txt b/tests/testthat/out/model_stack_reg.txt
index b3c80d11..c533716f 100644
--- a/tests/testthat/out/model_stack_reg.txt
+++ b/tests/testthat/out/model_stack_reg.txt
@@ -2,19 +2,18 @@
Message: -- A stacked ensemble model -------------------------------------
Message:
-Out of 5 possible candidate members, the ensemble retained 3.
+Out of 5 possible candidate members, the ensemble retained 2.
Penalty: 0.1.
Mixture: 1.
Message:
-The 3 highest weighted members are:
+The 2 highest weighted members are:
-# A tibble: 3 x 3
+# A tibble: 2 x 3
member type weight
-1 reg_res_svm_1_5 svm_rbf 2.88
-2 reg_res_svm_1_3 svm_rbf 0.895
-3 reg_res_svm_1_1 svm_rbf 0.410
+1 reg_res_svm_1_3 svm_rbf 1.26
+2 reg_res_svm_1_2 svm_rbf 0.136
Message:
Members have not yet been fitted with `fit_members()`.
diff --git a/tests/testthat/out/model_stack_reg_fit.txt b/tests/testthat/out/model_stack_reg_fit.txt
index bb5ea420..63307382 100644
--- a/tests/testthat/out/model_stack_reg_fit.txt
+++ b/tests/testthat/out/model_stack_reg_fit.txt
@@ -2,17 +2,16 @@
Message: -- A stacked ensemble model -------------------------------------
Message:
-Out of 5 possible candidate members, the ensemble retained 3.
+Out of 5 possible candidate members, the ensemble retained 2.
Penalty: 0.1.
Mixture: 1.
Message:
-The 3 highest weighted members are:
+The 2 highest weighted members are:
-# A tibble: 3 x 3
+# A tibble: 2 x 3
member type weight
-1 reg_res_svm_1_5 svm_rbf 2.88
-2 reg_res_svm_1_3 svm_rbf 0.895
-3 reg_res_svm_1_1 svm_rbf 0.410
+1 reg_res_svm_1_3 svm_rbf 1.26
+2 reg_res_svm_1_2 svm_rbf 0.136
diff --git a/tests/testthat/test_collect_parameters.R b/tests/testthat/test_collect_parameters.R
index ba0395b8..a3b106bd 100644
--- a/tests/testthat/test_collect_parameters.R
+++ b/tests/testthat/test_collect_parameters.R
@@ -55,7 +55,7 @@ test_that("collect_parameters on a data stack works (regression)", {
expect_equal(nrow(res), 5)
expect_equal(ncol(res2), 2)
- expect_equal(nrow(res2), 9)
+ expect_equal(nrow(res2), 10)
expect_equal(ncol(res3), 1)
expect_equal(nrow(res3), 1)
@@ -73,7 +73,7 @@ test_that("collect_parameters on a model stack works (regression)", {
expect_equal(nrow(res), 5)
expect_equal(ncol(res2), 3)
- expect_equal(nrow(res2), 9)
+ expect_equal(nrow(res2), 10)
expect_true(
all(
@@ -100,5 +100,5 @@ test_that("collect_parameters works (classification)", {
expect_equal(nrow(res), 10)
expect_equal(ncol(res2), 6)
- expect_equal(nrow(res2), 60)
+ expect_equal(nrow(res2), 57)
})