Skip to content

Commit

Permalink
Merge pull request #919 from tidymodels/check-outcome
Browse files Browse the repository at this point in the history
More verbose check_outcome()
  • Loading branch information
EmilHvitfeldt authored Mar 13, 2023
2 parents 00f3cdb + 3f924d8 commit 2ead20c
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 5 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

* `logistic_reg()` will now warn at `fit()` when the outcome has more than two levels (#545).

* Functions now indicate what class the outcome was if the outcome is the wrong class (#887).

# parsnip 1.0.4

Expand Down
18 changes: 15 additions & 3 deletions R/misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -336,14 +336,22 @@ check_outcome <- function(y, spec) {
if (spec$mode == "regression") {
outcome_is_numeric <- if (is.atomic(y)) {is.numeric(y)} else {all(map_lgl(y, is.numeric))}
if (!outcome_is_numeric) {
rlang::abort("For a regression model, the outcome should be numeric.")
cls <- class(y)[[1]]
abort(paste0(
"For a regression model, the outcome should be `numeric`, ",
"not a `", cls, "`."
))
}
}

if (spec$mode == "classification") {
outcome_is_factor <- if (is.atomic(y)) {is.factor(y)} else {all(map_lgl(y, is.factor))}
if (!outcome_is_factor) {
rlang::abort("For a classification model, the outcome should be a factor.")
cls <- class(y)[[1]]
abort(paste0(
"For a classification model, the outcome should be a `factor`, ",
"not a `", cls, "`."
))
}

if (inherits(spec, "logistic_reg") && is.atomic(y) && length(levels(y)) > 2) {
Expand All @@ -361,7 +369,11 @@ check_outcome <- function(y, spec) {
if (spec$mode == "censored regression") {
outcome_is_surv <- inherits(y, "Surv")
if (!outcome_is_surv) {
rlang::abort("For a censored regression model, the outcome should be a `Surv` object.")
cls <- class(y)[[1]]
abort(paste0(
"For a censored regression model, the outcome should be a `Surv` object, ",
"not a `", cls, "`."
))
}
}

Expand Down
24 changes: 24 additions & 0 deletions tests/testthat/_snaps/misc.md
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,27 @@
Error in `fn()`:
! Please use `new_data` instead of `newdata`.

# check_outcome works as expected

Code
check_outcome(factor(1:2), reg_spec)
Condition
Error in `check_outcome()`:
! For a regression model, the outcome should be `numeric`, not a `factor`.

---

Code
check_outcome(1:2, class_spec)
Condition
Error in `check_outcome()`:
! For a classification model, the outcome should be a `factor`, not a `integer`.

---

Code
check_outcome(1:2, cens_spec)
Condition
Error in `check_outcome()`:
! For a censored regression model, the outcome should be a `Surv` object, not a `integer`.

47 changes: 46 additions & 1 deletion tests/testthat/test_misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,51 @@ test_that('set_engine works as a generic', {
test_that('check_for_newdata points out correct context', {
fn <- function(...) {check_for_newdata(...); invisible()}
expect_snapshot(error = TRUE,
fn(newdata = "boop!")
fn(newdata = "boop!")
)
})

test_that('check_outcome works as expected', {
reg_spec <- linear_reg()

expect_no_error(
check_outcome(1:2, reg_spec)
)

expect_no_error(
check_outcome(mtcars, reg_spec)
)

expect_snapshot(
error = TRUE,
check_outcome(factor(1:2), reg_spec)
)

class_spec <- logistic_reg()

expect_no_error(
check_outcome(factor(1:2), class_spec)
)

expect_no_error(
check_outcome(lapply(mtcars, as.factor), class_spec)
)

expect_snapshot(
error = TRUE,
check_outcome(1:2, class_spec)
)

# Fake specification to avoid having to load {censored}
cens_spec <- logistic_reg()
cens_spec$mode <- "censored regression"

expect_no_error(
check_outcome(survival::Surv(1, 1), cens_spec)
)

expect_snapshot(
error = TRUE,
check_outcome(1:2, cens_spec)
)
})
2 changes: 1 addition & 1 deletion tests/testthat/test_nearest_neighbor_kknn.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ test_that('kknn execution', {
x = hpc[, num_pred],
y = hpc$input_fields
),
regexp = "outcome should be a factor"
regexp = "outcome should be a `factor`"
)

# nominal
Expand Down

0 comments on commit 2ead20c

Please sign in to comment.