Skip to content

Commit

Permalink
less computationally expensive examples
Browse files Browse the repository at this point in the history
  • Loading branch information
krzyzinskim committed Sep 29, 2023
1 parent 8dd891d commit c54b77f
Show file tree
Hide file tree
Showing 21 changed files with 48 additions and 82 deletions.
5 changes: 2 additions & 3 deletions R/plot_contribution.R
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,16 @@
#'
#'
#' @examples
#' \dontrun{
#' library(xgboost)
#' data <- fifa20$data[colnames(fifa20$data) != 'work_rate']
#' target <- fifa20$target
#' param <- list(objective = "reg:squarederror", max_depth = 3)
#' xgb_model <- xgboost::xgboost(as.matrix(data), params = param, label = target, nrounds = 200)
#' xgb_model <- xgboost::xgboost(as.matrix(data), params = param, label = target,
#' nrounds = 20, verbose = FALSE)
#' unified_model <- xgboost.unify(xgb_model, as.matrix(data))
#' x <- head(data, 1)
#' shap <- treeshap(unified_model, x)
#' plot_contribution(shap, 1, min_max = c(0, 120000000))
#' }
plot_contribution <- function(treeshap,
obs = 1,
max_vars = 5,
Expand Down
5 changes: 2 additions & 3 deletions R/plot_feature_dependence.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,16 @@
#'
#'
#' @examples
#' \dontrun{
#' library(xgboost)
#' data <- fifa20$data[colnames(fifa20$data) != 'work_rate']
#' target <- fifa20$target
#' param <- list(objective = "reg:squarederror", max_depth = 3)
#' xgb_model <- xgboost::xgboost(as.matrix(data), params = param, label = target, nrounds = 200)
#' xgb_model <- xgboost::xgboost(as.matrix(data), params = param, label = target,
#' nrounds = 20, verbose = FALSE)
#' unified_model <- xgboost.unify(xgb_model, as.matrix(data))
#' x <- head(data, 100)
#' shaps <- treeshap(unified_model, x)
#' plot_feature_dependence(shaps, variable = "overall")
#' }
plot_feature_dependence <- function(treeshap, variable,
title = "Feature Dependence", subtitle = NULL) {

Expand Down
5 changes: 2 additions & 3 deletions R/plot_feature_importance.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,15 @@
#'
#'
#' @examples
#' \dontrun{
#' library(xgboost)
#' data <- fifa20$data[colnames(fifa20$data) != 'work_rate']
#' target <- fifa20$target
#' param <- list(objective = "reg:squarederror", max_depth = 3)
#' xgb_model <- xgboost::xgboost(as.matrix(data), params = param, label = target, nrounds = 200)
#' xgb_model <- xgboost::xgboost(as.matrix(data), params = param, label = target,
#' nrounds = 20, verbose = FALSE)
#' unified_model <- xgboost.unify(xgb_model, as.matrix(data))
#' shaps <- treeshap(unified_model, as.matrix(head(data, 3)))
#' plot_feature_importance(shaps, max_vars = 4)
#' }
plot_feature_importance <- function(treeshap,
desc_sorting = TRUE,
max_vars = ncol(shaps),
Expand Down
2 changes: 0 additions & 2 deletions R/plot_interaction.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,13 @@
#'
#'
#' @examples
#' \donttest{
#' data <- fifa20$data[colnames(fifa20$data) != 'work_rate']
#' target <- fifa20$target
#' param2 <- list(objective = "reg:squarederror", max_depth = 5)
#' xgb_model2 <- xgboost::xgboost(as.matrix(data), params = param2, label = target, nrounds = 10)
#' unified_model2 <- xgboost.unify(xgb_model2, data)
#' inters <- treeshap(unified_model2, as.matrix(data[1:50, ]), interactions = TRUE)
#' plot_interaction(inters, "dribbling", "defending")
#' }
plot_interaction <- function(treeshap, var1, var2,
title = "SHAP Interaction Value Plot",
subtitle = "") {
Expand Down
19 changes: 8 additions & 11 deletions R/predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,18 @@
#' @export
#'
#' @examples
#' \dontrun{
#' library(gbm)
#' data <- fifa20$data[colnames(fifa20$data) != 'work_rate']
#' data['value_eur'] <- fifa20$target
#' gbm_model <- gbm::gbm(
#' formula = value_eur ~ .,
#' data = data,
#' distribution = "laplace",
#' n.trees = 1000,
#' cv.folds = 2,
#' interaction.depth = 2,
#' n.cores = 1)
#' unified <- gbm.unify(gbm_model, data)
#' predict(unified, data[3:7, ])
#'}
#' formula = value_eur ~ .,
#' data = data,
#' distribution = "laplace",
#' n.trees = 20,
#' interaction.depth = 4,
#' n.cores = 1)
#' unified <- gbm.unify(gbm_model, data)
#' predict(unified, data[2001:2005, ])
predict.model_unified <- function(object, x, ...) {
unified_model <- object
model <- unified_model$model
Expand Down
15 changes: 6 additions & 9 deletions R/set_reference_dataset.R
Original file line number Diff line number Diff line change
Expand Up @@ -26,21 +26,18 @@
#' \code{\link{randomForest.unify}} for \code{\link[randomForest:randomForest]{randomForest models}}
#'
#' @examples
#' \dontrun{
#' library(gbm)
#' data <- fifa20$data[colnames(fifa20$data) != 'work_rate']
#' data['value_eur'] <- fifa20$target
#' gbm_model <- gbm::gbm(
#' formula = value_eur ~ .,
#' data = data,
#' distribution = "laplace",
#' n.trees = 1000,
#' cv.folds = 2,
#' interaction.depth = 2,
#' n.cores = 1)
#' formula = value_eur ~ .,
#' data = data,
#' distribution = "laplace",
#' n.trees = 20,
#' interaction.depth = 4,
#' n.cores = 1)
#' unified <- gbm.unify(gbm_model, data)
#' set_reference_dataset(unified, data[200:700, ])
#'}
set_reference_dataset <- function(unified_model, x) {
model <- unified_model$model
data <- x
Expand Down
4 changes: 1 addition & 3 deletions R/unify_gbm.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,21 +21,19 @@
#' \code{\link{randomForest.unify}} for \code{\link[randomForest:randomForest]{randomForest models}}
#'
#' @examples
#'\donttest{
#' library(gbm)
#' data <- fifa20$data[colnames(fifa20$data) != 'work_rate']
#' data['value_eur'] <- fifa20$target
#' gbm_model <- gbm::gbm(
#' formula = value_eur ~ .,
#' data = data,
#' distribution = "gaussian",
#' n.trees = 50,
#' n.trees = 20,
#' interaction.depth = 4,
#' n.cores = 1)
#' unified_model <- gbm.unify(gbm_model, data)
#' shaps <- treeshap(unified_model, data[1:2,])
#' plot_contribution(shaps, obs = 1)
#' }
gbm.unify <- function(gbm_model, data) {
if(!inherits(gbm_model,'gbm')) {
stop('Object gbm_model was not of class "gbm"')
Expand Down
5 changes: 2 additions & 3 deletions R/unify_lightgbm.R
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@
#' \code{\link{randomForest.unify}} for \code{\link[randomForest:randomForest]{randomForest models}}
#'
#' @examples
#' \donttest{
#' library(lightgbm)
#' param_lgbm <- list(objective = "regression", max_depth = 2, force_row_wise = TRUE)
#' param_lgbm <- list(objective = "regression", max_depth = 2,
#' force_row_wise = TRUE, num_iterations = 20)
#' data_fifa <- fifa20$data[!colnames(fifa20$data) %in%
#' c('work_rate', 'value_eur', 'gk_diving', 'gk_handling',
#' 'gk_kicking', 'gk_reflexes', 'gk_speed', 'gk_positioning')]
Expand All @@ -40,7 +40,6 @@
#' unified_model <- lightgbm.unify(lgb_model, sparse_data)
#' shaps <- treeshap(unified_model, data[1:2, ])
#' plot_contribution(shaps, obs = 1)
#' }
lightgbm.unify <- function(lgb_model, data, recalculate = FALSE) {
if (!requireNamespace("lightgbm", quietly = TRUE)) {
stop("Package \"lightgbm\" needed for this function to work. Please install it.",
Expand Down
2 changes: 0 additions & 2 deletions R/unify_xgboost.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
#' \code{\link{randomForest.unify}} for \code{\link[randomForest:randomForest]{randomForest models}}
#'
#' @examples
#' \donttest{
#' library(xgboost)
#' data <- fifa20$data[colnames(fifa20$data) != 'work_rate']
#' target <- fifa20$target
Expand All @@ -31,7 +30,6 @@
#' unified_model <- xgboost.unify(xgb_model, as.matrix(data))
#' shaps <- treeshap(unified_model, data[1:2,])
#' plot_contribution(shaps, obs = 1)
#' }
#'
xgboost.unify <- function(xgb_model, data, recalculate = FALSE) {
if (!requireNamespace("xgboost", quietly = TRUE)) {
Expand Down
4 changes: 1 addition & 3 deletions man/gbm.unify.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 2 additions & 3 deletions man/lightgbm.unify.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 2 additions & 3 deletions man/plot_contribution.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 2 additions & 3 deletions man/plot_feature_dependence.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 2 additions & 3 deletions man/plot_feature_importance.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 0 additions & 2 deletions man/plot_interaction.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 8 additions & 11 deletions man/predict.model_unified.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 6 additions & 9 deletions man/set_reference_dataset.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 2 additions & 4 deletions man/treeshap.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit c54b77f

Please sign in to comment.