From 1dc84b5973e4222966917fdc7f2f39450ca05d70 Mon Sep 17 00:00:00 2001 From: Hannah Frick Date: Tue, 22 Mar 2022 10:05:34 +0000 Subject: [PATCH 1/3] export `xgb_pred()` --- R/boost_tree.R | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/R/boost_tree.R b/R/boost_tree.R index 86644820c..5d274034c 100644 --- a/R/boost_tree.R +++ b/R/boost_tree.R @@ -383,6 +383,10 @@ maybe_proportion <- function(x, nm) { } } +#' @rdname xgb_train +#' @param newdata A rectangular data object, such as a data frame. +#' @keywords internal +#' @export xgb_pred <- function(object, newdata, ...) { if (!inherits(newdata, "xgb.DMatrix")) { newdata <- maybe_matrix(newdata) From bbb89f8898b241f90c1ec12a0ef561ae4efbcd1d Mon Sep 17 00:00:00 2001 From: Hannah Frick Date: Thu, 24 Mar 2022 13:20:26 +0000 Subject: [PATCH 2/3] rename function and arg --- DESCRIPTION | 2 +- NAMESPACE | 1 + R/boost_tree.R | 22 +++++++++++----------- R/boost_tree_data.R | 20 ++++++++++---------- man/xgb_train.Rd | 11 ++++++++--- 5 files changed, 31 insertions(+), 25 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 9adfe4bb7..c52d95be1 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -89,4 +89,4 @@ Config/rcmdcheck/ignore-inconsequential-notes: true Encoding: UTF-8 LazyData: true Roxygen: list(markdown = TRUE) -RoxygenNote: 7.1.2 +RoxygenNote: 7.1.2.9000 diff --git a/NAMESPACE b/NAMESPACE index 5236d3b7f..485696775 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -294,6 +294,7 @@ export(update_main_parameters) export(update_model_info_file) export(varying) export(varying_args) +export(xgb_predict) export(xgb_train) importFrom(dplyr,arrange) importFrom(dplyr,bind_cols) diff --git a/R/boost_tree.R b/R/boost_tree.R index 5d274034c..28aa73eb0 100644 --- a/R/boost_tree.R +++ b/R/boost_tree.R @@ -217,8 +217,8 @@ check_args.boost_tree <- function(object) { #' Boosted trees via xgboost #' -#' `xgb_train` is a wrapper for `xgboost` tree-based models where all of the -#' model arguments are in the main function. +#' `xgb_train()` and `xgb_predict()` are wrappers for `xgboost` tree-based +#' models where all of the model arguments are in the main function. #' #' @param x A data frame or matrix of predictors #' @param y A vector (factor or numeric) or matrix (numeric) of outcome data. @@ -251,7 +251,7 @@ check_args.boost_tree <- function(object) { #' @param event_level For binary classification, this is a single string of either #' `"first"` or `"second"` to pass along describing which level of the outcome #' should be considered the "event". -#' @param ... Other options to pass to `xgb.train`. +#' @param ... Other options to pass to `xgb.train()` or xgboost's method for `predict()`. #' @return A fitted `xgboost` object. #' @keywords internal #' @export @@ -384,16 +384,16 @@ maybe_proportion <- function(x, nm) { } #' @rdname xgb_train -#' @param newdata A rectangular data object, such as a data frame. +#' @param new_data A rectangular data object, such as a data frame. #' @keywords internal #' @export -xgb_pred <- function(object, newdata, ...) { - if (!inherits(newdata, "xgb.DMatrix")) { - newdata <- maybe_matrix(newdata) - newdata <- xgboost::xgb.DMatrix(data = newdata, missing = NA) +xgb_predict <- function(object, new_data, ...) { + if (!inherits(new_data, "xgb.DMatrix")) { + new_data <- maybe_matrix(new_data) + new_data <- xgboost::xgb.DMatrix(data = new_data, missing = NA) } - res <- predict(object, newdata, ...) + res <- predict(object, new_data, ...) x <- switch( object$params$objective, @@ -486,9 +486,9 @@ multi_predict._xgb.Booster <- } xgb_by_tree <- function(tree, object, new_data, type, ...) { - pred <- xgb_pred( + pred <- xgb_predict( object$fit, - newdata = new_data, + new_data = new_data, iterationrange = c(1, tree + 1), ntreelimit = NULL ) diff --git a/R/boost_tree_data.R b/R/boost_tree_data.R index e7f7e9802..63c90b2d9 100644 --- a/R/boost_tree_data.R +++ b/R/boost_tree_data.R @@ -108,8 +108,8 @@ set_pred( value = list( pre = NULL, post = NULL, - func = c(fun = "xgb_pred"), - args = list(object = quote(object$fit), newdata = quote(new_data)) + func = c(fun = "xgb_predict"), + args = list(object = quote(object$fit), new_data = quote(new_data)) ) ) @@ -121,8 +121,8 @@ set_pred( value = list( pre = NULL, post = NULL, - func = c(fun = "xgb_pred"), - args = list(object = quote(object$fit), newdata = quote(new_data)) + func = c(fun = "xgb_predict"), + args = list(object = quote(object$fit), new_data = quote(new_data)) ) ) @@ -170,8 +170,8 @@ set_pred( } x }, - func = c(pkg = NULL, fun = "xgb_pred"), - args = list(object = quote(object$fit), newdata = quote(new_data)) + func = c(pkg = NULL, fun = "xgb_predict"), + args = list(object = quote(object$fit), new_data = quote(new_data)) ) ) @@ -196,8 +196,8 @@ set_pred( colnames(x) <- object$lvl x }, - func = c(pkg = NULL, fun = "xgb_pred"), - args = list(object = quote(object$fit), newdata = quote(new_data)) + func = c(pkg = NULL, fun = "xgb_predict"), + args = list(object = quote(object$fit), new_data = quote(new_data)) ) ) @@ -209,8 +209,8 @@ set_pred( value = list( pre = NULL, post = NULL, - func = c(fun = "xgb_pred"), - args = list(object = quote(object$fit), newdata = quote(new_data)) + func = c(fun = "xgb_predict"), + args = list(object = quote(object$fit), new_data = quote(new_data)) ) ) diff --git a/man/xgb_train.Rd b/man/xgb_train.Rd index 9b963ad11..580e3b7a6 100644 --- a/man/xgb_train.Rd +++ b/man/xgb_train.Rd @@ -2,6 +2,7 @@ % Please edit documentation in R/boost_tree.R \name{xgb_train} \alias{xgb_train} +\alias{xgb_predict} \title{Boosted trees via xgboost} \usage{ xgb_train( @@ -22,6 +23,8 @@ xgb_train( event_level = c("first", "second"), ... ) + +xgb_predict(object, new_data, ...) } \arguments{ \item{x}{A data frame or matrix of predictors} @@ -70,13 +73,15 @@ columns affects (instead of counts).} \code{"first"} or \code{"second"} to pass along describing which level of the outcome should be considered the "event".} -\item{...}{Other options to pass to \code{xgb.train}.} +\item{...}{Other options to pass to \code{xgb.train()} or xgboost's method for \code{predict()}.} + +\item{new_data}{A rectangular data object, such as a data frame.} } \value{ A fitted \code{xgboost} object. } \description{ -\code{xgb_train} is a wrapper for \code{xgboost} tree-based models where all of the -model arguments are in the main function. +\code{xgb_train()} and \code{xgb_predict()} are wrappers for \code{xgboost} tree-based +models where all of the model arguments are in the main function. } \keyword{internal} From 7bae7ce67d8c984903e85e60869cc6a04662dc5d Mon Sep 17 00:00:00 2001 From: Hannah Frick Date: Thu, 24 Mar 2022 13:28:53 +0000 Subject: [PATCH 3/3] update news --- NEWS.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/NEWS.md b/NEWS.md index 6235be5f0..91079ba9b 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,8 @@ # parsnip (development version) +* Exported `xgb_predict()` which wraps xgboost's `predict()` method for use with parsnip extension packages (#688). + + # parsnip 0.2.1 * Fixed a major bug in spark models induced in the previous version (#671). @@ -7,6 +10,7 @@ * Updated the parsnip add-in with new models and engines. * Updated parameter ranges for some `tunable()` methods and added a missing engine argument for brulee models. + * Added information about how to install the mixOmics package for PLS models (#680)