diff --git a/NEWS.md b/NEWS.md index 392774fd7..b056baad5 100644 --- a/NEWS.md +++ b/NEWS.md @@ -12,6 +12,8 @@ * The list column produced when creating survival probability predictions is now always called `.pred` (with `.pred_survival` being used inside of the list column). +* Fixed outcome type checking affecting a subset of regression models (#625). + ## Other Changes * When the xy interface is used and the underlying model expects to use a matrix, a better warning is issued when predictors contain non-numeric columns (including dates). diff --git a/R/fit_helpers.R b/R/fit_helpers.R index 0d20d5b93..9e820af69 100644 --- a/R/fit_helpers.R +++ b/R/fit_helpers.R @@ -6,17 +6,10 @@ form_form <- function(object, control, env, ...) { + check_outcome(eval_tidy(env$formula[[2]], env$data), object) + # prob rewrite this as simple subset/levels y_levels <- levels_from_formula(env$formula, env$data) - - if (object$mode == "classification") { - if (!inherits(env$data, "tbl_spark") && is.null(y_levels)) - rlang::abort("For a classification model, the outcome should be a factor.") - } else if (object$mode == "regression") { - if (!inherits(env$data, "tbl_spark") && !is.null(y_levels)) - rlang::abort("For a regression model, the outcome should be numeric.") - } - object <- check_mode(object, y_levels) # if descriptors are needed, update descr_env with the calculated values @@ -150,14 +143,7 @@ form_xy <- function(object, control, env, env$x <- data_obj$x env$y <- data_obj$y - res <- list(lvl = levels_from_formula(env$formula, env$data), spec = object) - if (object$mode == "classification") { - if (is.null(res$lvl)) - rlang::abort("For a classification model, the outcome should be a factor.") - } else if (object$mode == "regression") { - if (!is.null(res$lvl)) - rlang::abort("For a regression model, the outcome should be numeric.") - } + check_outcome(env$y, object) res <- xy_xy( object = object, diff --git a/tests/testthat/test_linear_reg.R b/tests/testthat/test_linear_reg.R index 0f1c0dd4c..f17bc120e 100644 --- a/tests/testthat/test_linear_reg.R +++ b/tests/testthat/test_linear_reg.R @@ -247,6 +247,16 @@ test_that('lm execution', { regexp = "For a regression model" ) + expect_error( + res <- fit_xy( + hpc_basic, + x = hpc[, num_pred], + y = as.character(hpc$class), + control = ctrl + ), + regexp = "For a regression model" + ) + expect_error( res <- fit( hpc_basic,