Plot trees in R #1222

Closed
sosmond opened this issue Jan 27, 2018 · 13 comments

@sosmond

sosmond commented Jan 27, 2018

The python package has plotting functions to visualize the trees grown. Is there a plan to add the same functionalities to the R package? Something like xgboost's xgb.plot.tree and xgb.plot.multi.trees using DiagrammeR.

@Laurae2
Contributor

Laurae2 commented Jan 30, 2018

@yanyachen any idea?

@zkurtz
Contributor

zkurtz commented Mar 6, 2018

The visualizations in this paper go well beyond what xgb.plot.tree currently offers; it would be great to have them in R for LightGBM: http://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=8017582

@tantrev

tantrev commented Mar 26, 2018

A function similar to xgb.plot.multi.trees would be especially useful for biological applications such as setting gates for cell sorting, where model interpretation is pivotal.

@randxie
Contributor

randxie commented Jun 5, 2018

@Laurae2 if no one is working on this, I can help with the feature request.

@zkurtz
Contributor

zkurtz commented Jun 6, 2018

@randxie In addition to my link above, another possible starting point is to extend SHAP (which is already implemented in python-lightgbm): https://github.com/slundberg/shap/blob/master/notebooks/Census%20income%20classification%20with%20LightGBM.ipynb

@Laurae2
Contributor

Laurae2 commented Jun 8, 2018

@randxie You can base the charts on the following R package if you want a baseline to compare with xgboost: https://github.com/dmlc/xgboost/tree/master/R-package/R

https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=8017582 is also extremely interesting and goes way beyond what LightGBM currently has, as @zkurtz pointed out earlier.

Interactive link: http://shixialiu.com:8082/static/index.html


@StrikerRUS
Collaborator

Closed in favor of tracking this in #2302. We decided to keep all feature requests in one place.

Contributions of this feature are welcome! Please re-open this issue (or post a comment if you are not the topic starter) if you are actively working on implementing it.

@jameslamb
Collaborator

Adding this comment to this discussion: #3188 (comment)

hey @randxie are you still interested in trying this 😛

@SpeckledJim2

I wrote an R function called lgb.plot.tree to graph a single tree from a LightGBM model, along similar lines to XGBoost's xgb.plot.multi.trees (but my function at the moment only plots a single LightGBM tree).

The function uses the DiagrammeR package (like XGBoost) to draw the tree graph.

There are three examples in the code at the bottom of this message. I have also attached the agaricus mushroom dataset in .csv format, which the code uses to graph an example tree from a LightGBM model with categorical features.

I think there is already a tree diagram function for the Python LightGBM package - if there are already some tests for that function then it would probably make sense to use similar tests for this function - happy to do the work to set up these tests.
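
For the categorical case, the code below relabels split thresholds that arrive as "||"-separated integer codes with the level names stored in the conversion rules. The following is a minimal base R sketch of that idea only. The `rules` list, the feature name `odor`, its levels, and the helper name `decode_split` are all made up for illustration and are not part of LightGBM.

```r
# Made-up rules for one feature: level names mapped to integer codes.
# The feature name "odor" and its levels are illustrative only.
rules <- list(odor = c(almond = 1, foul = 2, none = 3))

# Hypothetical helper: turn "1||3" into "almond\nnone".
decode_split <- function(split, feature_name, rules) {
  # order the levels by their integer code so position == code
  lvls <- sort(rules[[feature_name]])
  # one character vector of codes per split string
  parts <- strsplit(split, "||", fixed = TRUE)
  vapply(
    parts,
    function(codes) paste(names(lvls)[as.numeric(codes)], collapse = "\n"),
    character(1)
  )
}

decoded <- decode_split(c("1||3", "2"), "odor", rules)
print(decoded)
```

Joining the names with a newline keeps long category lists from stretching the edge label horizontally in the rendered graph.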

mushroom.csv

# libraries
library(lightgbm)
library(DiagrammeR)
library(data.table)
library(titanic)

# function to plot a single LightGBM tree using DiagrammeR
lgb.plot.tree <- function(model = NULL, tree = NULL, rules = NULL){
  # check model is lgb.Booster
  if (!inherits(model, "lgb.Booster")) {
    stop("model: Has to be an object of class lgb.Booster")
  }
  # check DiagrammeR is available
  if (!requireNamespace("DiagrammeR", quietly = TRUE)) {
    stop("DiagrammeR package is required for lgb.plot.tree", 
         call. = FALSE)
  }
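  # tree must be a single numeric value (integer input such as 0L is also accepted)
  if (!is.numeric(tree) || length(tree) != 1L || is.na(tree)) {
    stop("tree: Has to be an integer numeric")
  }
  # tree must be a non-negative whole number
  if (tree %% 1 != 0 || tree < 0) {
    stop("tree: Has to be an integer numeric")
  }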
  # extract data.table model structure
  dt <- lgb.model.dt.tree(model)
  # check that tree is less than or equal to the maximum tree index in the model
  if(tree>max(dt$tree_index)){
    stop("tree: has to be less than the number of trees in the model")
  }
  # filter dt to just the rows for the selected tree
  dt <- dt[tree_index==tree,]
  # change the column names to shorter more diagram friendly versions
  data.table::setnames(dt, old = c('tree_index','split_feature','threshold','split_gain'), new = c('Tree','Feature','Split','Gain'))
  dt[, Value:=0.0]
  dt[, Value:= leaf_value]
  dt[is.na(Value), Value := internal_value]
  dt[is.na(Gain), Gain := leaf_value]
  dt[is.na(Feature), Feature := 'Leaf']
  dt[, Cover := internal_count][Feature=='Leaf', Cover := leaf_count]
  dt[, c('leaf_count', 'internal_count','leaf_value','internal_value'):= NULL]
  dt[, Node := split_index]
  max_node <- max(dt[['Node']], na.rm = TRUE)
  dt[is.na(Node), Node := max_node + leaf_index +1]
  dt[, ID := paste(Tree, Node, sep = '-')]
  dt[, c('depth','leaf_index') := NULL]
  dt[, parent := node_parent][is.na(parent), parent := leaf_parent]
  dt[, c('node_parent', 'leaf_parent','split_index') := NULL]
  dt[, Yes := dt$ID[match(dt$Node, dt$parent)]]
  dt <- dt[nrow(dt):1,]
  dt[, No := dt$ID[match(dt$Node, dt$parent)]]
  # which way do the NA's go (this path will get a thicker arrow)
  # for categorical features, NA gets put into the zero group
  dt[default_left==TRUE, Missing := Yes]
  dt[default_left==FALSE, Missing := No]
  zero_present <- function(x){sapply(strsplit(as.character(x),'||',fixed = TRUE), function(el){any(el=='0')})}
  dt[zero_present(Split), Missing := Yes]
  #dt[, c('parent', 'default_left') := NULL]
  #data.table::setcolorder(dt, c('Tree','Node','ID','Feature','decision_type','Split','Yes','No','Missing','Gain','Cover','Value'))
  # create the label text
  dt[, label:= paste0(Feature,
                      "\nCover: ", Cover,
                      ifelse(Feature == "Leaf", "", "\nGain: "), ifelse(Feature == "Leaf", "", round(Gain, 4)),
                      "\nValue: ", round(Value, 4)
  )]
  # style the nodes - same format as xgboost
  dt[Node == 0, label := paste0("Tree ", Tree, "\n", label)]
  dt[, shape:= "rectangle"][Feature == "Leaf", shape:= "oval"]
  dt[, filledcolor:= "Beige"][Feature == "Leaf", filledcolor:= "Khaki"]
  # in order to draw the first tree on top:
  dt <- dt[order(-Tree)]
  nodes <- DiagrammeR::create_node_df(
    n         = nrow(dt),
    ID        = dt$ID,
    label     = dt$label,
    fillcolor = dt$filledcolor,
    shape     = dt$shape,
    data      = dt$Feature,
    fontcolor = "black")
  # round the edge labels to 4 decimal places if they are numeric
  # as otherwise get too many decimal places and the diagram looks bad
  # would rather not use suppressWarnings
  numeric_idx <- suppressWarnings(!is.na(as.numeric(dt[['Split']])))
  dt[numeric_idx, Split := round(as.numeric(Split),4)]
  # replace indices with feature levels if rules supplied
  levels.to.names <- function(x, feature_name, rules){
    lvls <- sort(rules[[feature_name]])
    result <- strsplit(x,'||', fixed = TRUE)
    result <- lapply(result, as.numeric)
    levels_to_names <- function(x){names(lvls)[as.numeric(x)]}
    result <- lapply(result, levels_to_names)
    result <- lapply(result, paste, collapse = '\n')
    result <- as.character(result)
  }
  if(!is.null(rules)){
    for (f in names(rules)){
      dt[Feature==f & decision_type == '==', Split := levels.to.names(Split, f, rules)]
    }
  }
  # replace long split names with a message
  dt[nchar(Split)>500, Split := 'Split too long to render']
  # create the edge labels
  edges <- DiagrammeR::create_edge_df(
    from  = match(dt[Feature != "Leaf", c(ID)] %>% rep(2), dt$ID),
    to    = match(dt[Feature != "Leaf", c(Yes, No)], dt$ID),
    label = dt[Feature != "Leaf", paste(decision_type, Split)] %>%
      c(rep("", nrow(dt[Feature != "Leaf"]))),
    style = dt[Feature != "Leaf", ifelse(Missing == Yes, "bold", "solid")] %>%
      c(dt[Feature != "Leaf", ifelse(Missing == No, "bold", "solid")]),
    rel   = "leading_to")
  # create the graph
  graph <- DiagrammeR::create_graph(
    nodes_df = nodes,
    edges_df = edges,
    attr_theme = NULL
  ) %>%
    DiagrammeR::add_global_graph_attrs(
      attr_type = "graph",
      attr  = c("layout", "rankdir"),
      value = c("dot", "LR")
    ) %>%
    DiagrammeR::add_global_graph_attrs(
      attr_type = "node",
      attr  = c("color", "style", "fontname"),
      value = c("DimGray", "filled", "Helvetica")
    ) %>%
    DiagrammeR::add_global_graph_attrs(
      attr_type = "edge",
      attr  = c("color", "arrowsize", "arrowhead", "fontname"),
      value = c("DimGray", "1.5", "vee", "Helvetica"))
  # render the graph
  DiagrammeR::render_graph(graph)
}

# EXAMPLE 1: use the LightGBM example dataset to build a model with a single tree
data(agaricus.train, package = "lightgbm")
train <- agaricus.train
dtrain <- lgb.Dataset(train$data, label = train$label)
data(agaricus.test, package = "lightgbm")
test <- agaricus.test
dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
# define model parameters and build a single tree
params <- list(
  objective = "regression"
  , metric = "l2"
  , min_data = 1L
  , learning_rate = 1.0
)
valids <- list(test = dtest)
model <- lgb.train(
  params = params
  , data = dtrain
  , nrounds = 1L
  , valids = valids
  , early_stopping_rounds = 1L
)
# plot the tree and compare to the tree table
# trees start from 0 in lgb.model.dt.tree
tree_table <- lgb.model.dt.tree(model)
lgb.plot.tree(model, 0)

# EXAMPLE 2: same mushroom dataset, but model using categorical features
# change class to a 0/1 numeric
# deliberately set odor level a to NA to test thicker arrows for NA path
mushroom <- fread('mushroom.csv', stringsAsFactors = TRUE)
mushroom[odor=='a', odor := NA]
mushroom[, class := ifelse(class=='p',1,0)]
cat_cols <- setdiff(names(mushroom), 'class')
# split into train and test, label and data
set.seed(42)
train_rows <- sample(1:nrow(mushroom), 4062, replace = FALSE)
test_rows <- setdiff(1:nrow(mushroom), train_rows)
train_response <- mushroom[['class']][train_rows]
test_response <- mushroom[['class']][test_rows]
d_converted <- lgb.convert_with_rules(data = mushroom[, ..cat_cols])
d_train <- d_converted$data[train_rows]
d_test <- d_converted$data[test_rows]
# create train and test datasets for LGBM
l_train <- lgb.Dataset(as.matrix(d_train), label=train_response, categorical_feature = cat_cols)
l_test <- lgb.Dataset.create.valid(l_train, as.matrix(d_test), label = test_response)
# build the model
model_2 <- lgb.train(params = params
                     , nrounds = 1L
                     , data = l_train
                     , valids = list('test'=l_test)
                     , categorical_feature = cat_cols
                     , early_stopping_rounds = 1L
                     )
# plot some trees with and without using the rules
# and compare to tree table
tree_table_2 <- lgb.model.dt.tree(model_2)
lgb.plot.tree(model_2, 0, rules = NULL) # this will use the integer encodings
lgb.plot.tree(model_2, 0, rules = d_converted$rules) # this replaces the encodings with level description

# EXAMPLE 3: titanic
# only titanic_train contains the Survived column, so split that into train and test
train_rows <- 1:700
test_rows <- 701:nrow(titanic_train) # start at 701 so row 700 is not in both train and test
train_response <- titanic_train[['Survived']][train_rows]
test_response <- titanic_train[['Survived']][test_rows]
cols <- c('Pclass','Sex','Age','SibSp','Parch','Fare','Embarked')
d_converted <- lgb.convert_with_rules(data = titanic_train[cols])
d_train <- d_converted$data[train_rows,]
d_test <- d_converted$data[test_rows,]
# create train and test datasets for LGBM
l_train <- lgb.Dataset(as.matrix(d_train), label=train_response, categorical_feature = c('Sex','Embarked'))
l_test <- lgb.Dataset.create.valid(l_train, as.matrix(d_test), label = test_response)
# build the model
params_titanic <- list(
  objective = "binary"
  , metric = "binary_logloss"
  , min_data = 1L
  , learning_rate = 0.3
  , num_leaves = 10
)

model_3 <- lgb.train(params = params_titanic
                     , nrounds = 100L
                     , data = l_train
                     , valids = list('test'=l_test)
                     , early_stopping_rounds = 10L
)
# plot some trees with and without using the rules
# and compare to tree table
tree_table_3 <- lgb.model.dt.tree(model_3)
lgb.plot.tree(model_3, 0, rules = d_converted$rules)
lgb.plot.tree(model_3, 1, rules = d_converted$rules)
lgb.plot.tree(model_3, 3, rules = d_converted$rules)
lgb.plot.tree(model_3, 4, rules = d_converted$rules)
# lgb.plot.tree(model_3, 999, rules = d_converted$rules) # deliberately greater than number of trees - will throw an error
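
For readers following the table manipulation inside lgb.plot.tree above, the node numbering and child lookup steps can be sketched in isolation. This is a minimal base R illustration on a hand-made table, not output from lgb.model.dt.tree; the toy values and the helper name `number_nodes` are assumptions made for the example.

```r
# Toy stand-in for one tree: two internal splits and three leaves.
# Column names mirror those used in the function above; the values are made up.
toy <- data.frame(
  split_index = c(0, 1, NA, NA, NA),
  leaf_index  = c(NA, NA, 0, 1, 2),
  node_parent = c(NA, 0, NA, NA, NA),
  leaf_parent = c(NA, NA, 1, 1, 0)
)

# Hypothetical helper: give every row a unique node number and a parent.
number_nodes <- function(df) {
  # internal nodes keep their split index
  node <- df$split_index
  max_node <- max(node, na.rm = TRUE)
  # leaves are numbered after the last internal node
  is_leaf <- is.na(node)
  node[is_leaf] <- max_node + df$leaf_index[is_leaf] + 1
  # a row's parent comes from node_parent for splits, leaf_parent for leaves
  parent <- ifelse(is.na(df$node_parent), df$leaf_parent, df$node_parent)
  data.frame(node = node, parent = parent)
}

numbered <- number_nodes(toy)

# match() returns the first hit, so scanning the parent column forwards and
# backwards picks out two different children of each internal node.
first_child <- numbered$node[match(numbered$node, numbered$parent)]
rev_tbl     <- numbered[nrow(numbered):1, ]
last_child  <- rev_tbl$node[match(numbered$node, rev_tbl$parent)]
print(cbind(numbered, first_child, last_child))
```

In this toy tree the root (node 0) gets children 1 and 4, node 1 gets children 2 and 3, and the leaf rows get NA for both, which is why the function only builds edges from rows that are not leaves.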

@jameslamb
Collaborator

Thanks very much for your hard work @SpeckledJim2 ! As long as we can make {DiagrammeR} an optional dependency (as you've done with the use of requireNamespace()), I'm open to including functionality like this in the R package.
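
As a rough illustration of the optional-dependency pattern mentioned here, the check can be factored into a small helper. This is only a sketch: the helper name `require_optional` and the package name `notARealPackage123` are hypothetical and chosen to exercise the error path.

```r
# Hypothetical helper: stop with a clear message when an optional
# (Suggests-level) package is not installed.
require_optional <- function(pkg, caller) {
  if (!requireNamespace(pkg, quietly = TRUE)) {
    stop(
      sprintf("The '%s' package is required for %s", pkg, caller),
      call. = FALSE
    )
  }
  invisible(TRUE)
}

# A plotting function would call the check before touching the package.
# "notARealPackage123" is a made-up name used to trigger the error path.
result <- tryCatch(
  require_optional("notARealPackage123", "lgb.plot.tree()"),
  error = function(e) conditionMessage(e)
)
print(result)
```

Because the package is only looked up when the plotting function is called, users who never plot trees do not need it installed.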

Please open a pull request which adds this function in a new file R-package/R/lgb.plot.tree.R. At a minimum, please:

  • add {DiagrammeR} as a Suggests-level dependency in R-package/DESCRIPTION
  • add lgb.plot.tree() to R-package/pkgdown/_pkgdown.yml
  • add {roxygen2} documentation and regenerate the package docs by running the following
    •   sh build-cran-package.sh --no-build-vignettes
        R CMD INSTALL --with-keep.source ./lightgbm_3.3.2.99.tar.gz
        cd R-package
        Rscript -e "roxygen2::roxygenize(load = 'installed')"
  • add a new file R-package/tests/testthat/test_lgb.plot.tree.R with at least one unit test on this functionality.

We can use the pull request for further feedback and suggestions.
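
As one possible starting point for the unit test, the argument checks on `tree` are easy to exercise. To stay self-contained, the sketch below tests a hypothetical stand-alone validator, `check_tree_index()`, with base R rather than calling lgb.plot.tree() through {testthat}. The validator is loosely modelled on the checks in the posted function and, as an assumption of this sketch, also accepts integer input such as `2L`.

```r
# Hypothetical validator loosely modelled on the checks in lgb.plot.tree().
# `max_index` stands in for max(dt$tree_index) from lgb.model.dt.tree().
check_tree_index <- function(tree, max_index) {
  if (!is.numeric(tree) || length(tree) != 1L || is.na(tree)) {
    stop("tree: Has to be an integer numeric")
  }
  if (tree %% 1 != 0 || tree < 0) {
    stop("tree: Has to be an integer numeric")
  }
  if (tree > max_index) {
    stop("tree: has to be less than the number of trees in the model")
  }
  invisible(tree)
}

# Small helper: TRUE when evaluating `expr` raises an error containing `pattern`.
fails_with <- function(expr, pattern) {
  msg <- tryCatch(
    {
      expr
      NA_character_
    },
    error = function(e) conditionMessage(e)
  )
  !is.na(msg) && grepl(pattern, msg, fixed = TRUE)
}

checks <- c(
  accepts_double   = identical(check_tree_index(0, 4), 0),
  accepts_integer  = identical(check_tree_index(2L, 4), 2L),
  rejects_string   = fails_with(check_tree_index("0", 4), "integer numeric"),
  rejects_fraction = fails_with(check_tree_index(1.5, 4), "integer numeric"),
  rejects_too_big  = fails_with(check_tree_index(999, 4), "less than the number of trees")
)
print(checks)
```

In the package itself, each of these would more naturally become an `expect_error()` or `expect_silent()` call against a small trained model.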

@github-actions github-actions bot locked as resolved and limited conversation to collaborators Aug 15, 2023
@jameslamb
Collaborator

This was locked accidentally. I just unlocked it. We'd still welcome contributions related to this feature!

@microsoft microsoft unlocked this conversation Aug 18, 2023
@yuanqingye

The lgb.plot.tree function that @SpeckledJim2 posted above seems to work for me.
