Skip to content

Commit

Permalink
Use check_outcome for all fit paths (#625)
Browse files Browse the repository at this point in the history
* Use `check_outcome` for all fit paths

* Update NEWS
  • Loading branch information
juliasilge authored Dec 15, 2021
1 parent b19f0e7 commit b6db676
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 17 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

* The list column produced when creating survival probability predictions is now always called `.pred` (with `.pred_survival` being used inside of the list column).

* Fixed outcome type checking affecting a subset of regression models (#625).

## Other Changes

* When the xy interface is used and the underlying model expects to use a matrix, a better warning is issued when predictors contain non-numeric columns (including dates).
Expand Down
20 changes: 3 additions & 17 deletions R/fit_helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,10 @@
form_form <-
function(object, control, env, ...) {

check_outcome(eval_tidy(env$formula[[2]], env$data), object)

# prob rewrite this as simple subset/levels
y_levels <- levels_from_formula(env$formula, env$data)

if (object$mode == "classification") {
if (!inherits(env$data, "tbl_spark") && is.null(y_levels))
rlang::abort("For a classification model, the outcome should be a factor.")
} else if (object$mode == "regression") {
if (!inherits(env$data, "tbl_spark") && !is.null(y_levels))
rlang::abort("For a regression model, the outcome should be numeric.")
}

object <- check_mode(object, y_levels)

# if descriptors are needed, update descr_env with the calculated values
Expand Down Expand Up @@ -150,14 +143,7 @@ form_xy <- function(object, control, env,
env$x <- data_obj$x
env$y <- data_obj$y

res <- list(lvl = levels_from_formula(env$formula, env$data), spec = object)
if (object$mode == "classification") {
if (is.null(res$lvl))
rlang::abort("For a classification model, the outcome should be a factor.")
} else if (object$mode == "regression") {
if (!is.null(res$lvl))
rlang::abort("For a regression model, the outcome should be numeric.")
}
check_outcome(env$y, object)

res <- xy_xy(
object = object,
Expand Down
10 changes: 10 additions & 0 deletions tests/testthat/test_linear_reg.R
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,16 @@ test_that('lm execution', {
regexp = "For a regression model"
)

expect_error(
res <- fit_xy(
hpc_basic,
x = hpc[, num_pred],
y = as.character(hpc$class),
control = ctrl
),
regexp = "For a regression model"
)

expect_error(
res <- fit(
hpc_basic,
Expand Down

0 comments on commit b6db676

Please sign in to comment.