Skip to content

Commit

Permalink
patch params argument with xgboost engine in boost_tree() (#787)
Browse files Browse the repository at this point in the history
* patch `params` argument with `xgboost` engine in `boost_tree()`

* remove + add snapshots from previous PRs

* update snaps with new help-page reference

Co-authored-by: Max Kuhn <mxkuhn@gmail.com>
  • Loading branch information
simonpcouch and topepo authored Aug 17, 2022
1 parent 9e36249 commit 6c5482a
Show file tree
Hide file tree
Showing 10 changed files with 342 additions and 52 deletions.
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# parsnip (development version)


* Enabled passing additional engine arguments with the xgboost `boost_tree()` engine. To supply engine-specific arguments that are documented in `xgboost::xgb.train()` as arguments to be passed via `params`, supply the list elements directly as named arguments to `set_engine()`. Read more in `?details_boost_tree_xgboost` (#787).

# parsnip 1.0.0

## Model Specification Changes
Expand Down
89 changes: 59 additions & 30 deletions R/boost_tree.R
Original file line number Diff line number Diff line change
Expand Up @@ -213,9 +213,6 @@ check_args.boost_tree <- function(object) {
#' @param counts A logical. If `FALSE`, `colsample_bynode` and
#' `colsample_bytree` are both assumed to be _proportions_ of the proportion of
#' columns affects (instead of counts).
#' @param objective A single string (or NULL) that defines the loss function that
#' `xgboost` uses to create trees. See [xgboost::xgb.train()] for options. If left
#' NULL, an appropriate loss function is chosen.
#' @param event_level For binary classification, this is a single string of either
#' `"first"` or `"second"` to pass along describing which level of the outcome
#' should be considered the "event".
Expand All @@ -227,7 +224,7 @@ xgb_train <- function(
x, y, weights = NULL,
max_depth = 6, nrounds = 15, eta = 0.3, colsample_bynode = NULL,
colsample_bytree = NULL, min_child_weight = 1, gamma = 0, subsample = 1,
validation = 0, early_stop = NULL, objective = NULL, counts = TRUE,
validation = 0, early_stop = NULL, counts = TRUE,
event_level = c("first", "second"), ...) {

event_level <- rlang::arg_match(event_level, c("first", "second"))
Expand All @@ -248,18 +245,6 @@ xgb_train <- function(
}
}

if (is.null(objective)) {
if (is.numeric(y)) {
objective <- "reg:squarederror"
} else {
if (num_class == 2) {
objective <- "binary:logistic"
} else {
objective <- "multi:softprob"
}
}
}

n <- nrow(x)
p <- ncol(x)

Expand Down Expand Up @@ -300,35 +285,79 @@ xgb_train <- function(
colsample_bytree = colsample_bytree,
colsample_bynode = colsample_bynode,
min_child_weight = min(min_child_weight, n),
subsample = subsample,
objective = objective
subsample = subsample
)

main_args <- list(
data = quote(x$data),
watchlist = quote(x$watchlist),
params = arg_list,
nrounds = nrounds,
early_stopping_rounds = early_stop
others <- process_others(others, arg_list)

main_args <- c(
list(
data = quote(x$data),
watchlist = quote(x$watchlist),
params = arg_list,
nrounds = nrounds,
early_stopping_rounds = early_stop
),
others
)

if (is.null(main_args$objective)) {
if (is.numeric(y)) {
main_args$objective <- "reg:squarederror"
} else {
if (num_class == 2) {
main_args$objective <- "binary:logistic"
} else {
main_args$objective <- "multi:softprob"
}
}
}

if (!is.null(num_class) && num_class > 2) {
main_args$num_class <- num_class
}

call <- make_call(fun = "xgb.train", ns = "xgboost", main_args)

# override or add some other args
eval_tidy(call, env = current_env())
}

process_others <- function(others, arg_list) {
guarded <- c("data", "weights", "num_class", "watchlist")
guarded_supplied <- names(others)[names(others) %in% guarded]

if (length(guarded_supplied) > 0) {
cli::cli_warn(
c(
"!" = "{cli::qty(guarded_supplied)} The argument{?s} {.arg {guarded_supplied}} \
{?is/are} guarded by parsnip and will not be passed to {.fun xgb.train}."
),
class = "xgboost_guarded_warning"
)
}

others <-
others[!(names(others) %in% c("data", "weights", "nrounds", "num_class", names(arg_list)))]
others[!(names(others) %in% guarded)]

if (!is.null(others$params)) {
cli::cli_warn(
c(
"!" = "Please supply elements of the `params` list argument as main arguments \
to `set_engine()` rather than as part of `params`.",
"i" = "See `?details_boost_tree_xgboost` for more information."
),
class = "xgboost_params_warning"
)

params <- others$params[!names(others$params) %in% names(arg_list)]
others <- c(others[names(others) != "params"], params)
}

if (!(any(names(others) == "verbose"))) {
others$verbose <- 0
}
if (length(others) > 0) {
call <- rlang::call_modify(call, !!!others)
}

eval_tidy(call, env = current_env())
others
}

recalc_param <- function(x, counts, denom) {
Expand Down
51 changes: 50 additions & 1 deletion man/details_boost_tree_xgboost.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

24 changes: 23 additions & 1 deletion man/rmd/boost_tree_xgboost.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,26 @@ For classification, non-numeric outcomes (i.e., factors) are internally converte

## Other details

### Interfacing with the `params` argument

The xgboost function that parsnip indirectly wraps, [xgboost::xgb.train()], takes most arguments via the `params` list argument. To supply engine-specific arguments that are documented in [xgboost::xgb.train()] as arguments to be passed via `params`, supply the list elements directly as named arguments to [set_engine()] rather than as elements in `params`. For example, pass a non-default evaluation metric like this:

```{r}
# good
boost_tree() %>%
set_engine("xgboost", eval_metric = "mae")
```

...rather than this:

```{r}
# bad
boost_tree() %>%
set_engine("xgboost", params = list(eval_metric = "mae"))
```

parsnip will then route arguments as needed. In the case that arguments are passed to `params` via [set_engine()], parsnip will warn and re-route the arguments as needed. Note, though, that arguments passed to `params` cannot be tuned.

### Sparse matrices

xgboost requires the data to be in a sparse format. If your predictor data are already in this format, then use [fit_xy.model_spec()] to pass it to the model function. Otherwise, parsnip converts the data to this format.
Expand All @@ -78,9 +98,11 @@ By default, the model is trained without parallel processing. This can be change
```{r child = "template-early-stopping.Rmd"}
```

Note that, since the `validation` argument provides an alternative interface to `watchlist`, the `watchlist` argument is guarded by parsnip and will be ignored (with a warning) if passed.

### Objective function

parsnip chooses the objective function based on the characteristics of the outcome. To use a different loss, pass the `objective` argument to [set_engine()].
parsnip chooses the objective function based on the characteristics of the outcome. To use a different loss, pass the `objective` argument to [set_engine()] directly.

## Examples

Expand Down
44 changes: 43 additions & 1 deletion man/rmd/boost_tree_xgboost.md
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,46 @@ For classification, non-numeric outcomes (i.e., factors) are internally converte

## Other details

### Interfacing with the `params` argument

The xgboost function that parsnip indirectly wraps, [xgboost::xgb.train()], takes most arguments via the `params` list argument. To supply engine-specific arguments that are documented in [xgboost::xgb.train()] as arguments to be passed via `params`, supply the list elements directly as named arguments to [set_engine()] rather than as elements in `params`. For example, pass a non-default evaluation metric like this:


```r
# good
boost_tree() %>%
set_engine("xgboost", eval_metric = "mae")
```

```
## Boosted Tree Model Specification (unknown)
##
## Engine-Specific Arguments:
## eval_metric = mae
##
## Computational engine: xgboost
```

...rather than this:


```r
# bad
boost_tree() %>%
set_engine("xgboost", params = list(eval_metric = "mae"))
```

```
## Boosted Tree Model Specification (unknown)
##
## Engine-Specific Arguments:
## params = list(eval_metric = "mae")
##
## Computational engine: xgboost
```

parsnip will then route arguments as needed. In the case that arguments are passed to `params` via [set_engine()], parsnip will warn and re-route the arguments as needed. Note, though, that arguments passed to `params` cannot be tuned.

### Sparse matrices

xgboost requires the data to be in a sparse format. If your predictor data are already in this format, then use [fit_xy.model_spec()] to pass it to the model function. Otherwise, parsnip converts the data to this format.
Expand Down Expand Up @@ -137,9 +177,11 @@ The best way to use this feature is in conjunction with an _internal validation

If the model specification has `early_stop >= trees`, `early_stop` is converted to `trees - 1` and a warning is issued.

Note that, since the `validation` argument provides an alternative interface to `watchlist`, the `watchlist` argument is guarded by parsnip and will be ignored (with a warning) if passed.

### Objective function

parsnip chooses the objective function based on the characteristics of the outcome. To use a different loss, pass the `objective` argument to [set_engine()].
parsnip chooses the objective function based on the characteristics of the outcome. To use a different loss, pass the `objective` argument to [set_engine()] directly.

## Examples

Expand Down
5 changes: 0 additions & 5 deletions man/xgb_train.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

23 changes: 23 additions & 0 deletions tests/testthat/_snaps/boost_tree_xgboost.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# interface to param arguments

! Please supply elements of the `params` list argument as main arguments to `set_engine()` rather than as part of `params`.
i See `?details_boost_tree_xgboost` for more information.

---

! Please supply elements of the `params` list argument as main arguments to `set_engine()` rather than as part of `params`.
i See `?details_boost_tree_xgboost` for more information.

---

! The argument `watchlist` is guarded by parsnip and will not be passed to `xgb.train()`.

---

! The arguments `watchlist` and `data` are guarded by parsnip and will not be passed to `xgb.train()`.

---

! Please supply elements of the `params` list argument as main arguments to `set_engine()` rather than as part of `params`.
i See `?details_boost_tree_xgboost` for more information.

14 changes: 0 additions & 14 deletions tests/testthat/_snaps/proportional_hazards.md
Original file line number Diff line number Diff line change
@@ -1,17 +1,3 @@
# printing

Code
proportional_hazards()
Message
parsnip could not locate an implementation for `proportional_hazards` censored regression model specifications using the `survival` engine.
i The parsnip extension package censored implements support for this specification.
i Please install (if needed) and load to continue.
Output
Proportional Hazards Model Specification (censored regression)
Computational engine: survival

# updating

Code
Expand Down
Loading

0 comments on commit 6c5482a

Please sign in to comment.