Trees with linear models at leaves (#3299)
* Add Eigen library.

* Working for simple test.

* Apply changes to config params.

* Handle nan data.

* Update docs.

* Add test.

* Only load raw data if boosting=gbdt_linear

* Remove unneeded code.

* Minor updates.

* Update to work with sk-learn interface.

* Update to work with chunked datasets.

* Throw error if we try to create a Booster with an already-constructed dataset having incompatible parameters.

* Save raw data in binary dataset file.

* Update docs and fix parameter checking.

* Fix dataset loading.

* Add test for regularization.

* Fix bugs when saving and loading tree.

* Add test for load/save linear model.

* Remove unneeded code.

* Fix case where not enough leaf data for linear model.

* Simplify code.

* Speed up code.

* Speed up code.

* Simplify code.

* Speed up code.

* Fix bugs.

* Working version.

* Store feature data column-wise (not fully working yet).

* Fix bugs.

* Speed up.

* Speed up.

* Remove unneeded code.

* Small speedup.

* Speed up.

* Minor updates.

* Remove unneeded code.

* Fix bug.

* Fix bug.

* Speed up.

* Speed up.

* Simplify code.

* Remove unneeded code.

* Fix bug, add more tests.

* Fix bug and add test.

* Only store numerical features

* Fix bug and speed up using templates.

* Speed up prediction.

* Fix bug with regularisation

* Visual studio files.

* Working version

* Only check nans if necessary

* Store coeff matrix as an array.

* Align cache lines

* Align cache lines

* Preallocation coefficient calculation matrices

* Small speedups

* Small speedup

* Reverse cache alignment changes

* Change to dynamic schedule

* Update docs.

* Refactor so that linear tree learner is not a separate class.

* Add refit capability.

* Speed up

* Small speedups.

* Speed up add prediction to score.

* Fix bug

* Fix bug and speed up.

* Speed up dataload.

* Speed up dataload

* Use vectors instead of pointers

* Fix bug

* Add OMP exception handling.

* Change return type of LGBM_BoosterGetLinear to bool

* Change return type of LGBM_BoosterGetLinear back to int, only parameter type needed to change

* Remove unused internal_parent_ property of tree

* Remove unused parameter to CreateTreeLearner

* Remove reference to LinearTreeLearner

* Minor style issues

* Remove unneeded check

* Reverse temporary testing change

* Fix Visual Studio project files

* Restore LightGBM.vcxproj.filters

* Speed up

* Speed up

* Simplify code

* Update docs

* Simplify code

* Initialise storage space for max num threads

* Move Eigen to include directory and delete unused files

* Remove old files.

* Fix so it compiles with mingw

* Fix gpu tree learner

* Change AddPredictionToScore back to const

* Fix python lint error

* Fix C++ lint errors

* Change eigen to a submodule

* Update comment

* Add the eigen folder

* Try to fix build issues with eigen

* Remove eigen files

* Add eigen as submodule

* Fix include paths

* Exclude eigen files from Python linter

* Ignore eigen folders for pydocstyle

* Fix C++ linting errors

* Fix docs

* Fix docs

* Exclude eigen directories from doxygen

* Update manifest to include eigen

* Update build_r to include eigen files

* Fix compiler warnings

* Store raw feature data as float

* Use float for calculating linear coefficients

* Remove eigen directory from GLOB

* Don't compile linear model code when building R package

* Fix doxygen issue

* Fix lint issue

* Fix lint issue

* Remove unneeded code

* Restore deleted lines

* Restore deleted lines

* Change return type of has_raw to bool

* Update docs

* Rename some variables and functions for readability

* Make tree_learner parameter const in AddScore

* Fix style issues

* Pass vectors as const reference when setting tree properties

* Make temporary storage of serial_tree_learner mutable so we can make the object's methods const

* Remove get_raw_size, use num_numeric_features instead

* Fix typo

* Make contains_nan_ and any_nan_ properties immutable again

* Remove data_has_nan_ property of tree

* Remove temporary test code

* Make linear_tree a dataset param

* Fix lint error

* Make LinearTreeLearner a separate class

* Fix lint errors

* Fix lint error

* Add linear_tree_learner.o

* Simulate omp_get_max_threads if openmp is not available

* Update PushOneData to also store raw data.

* Cast size to int

* Fix bug in ReshapeRaw

* Speed up code with multithreading

* Use OMP_NUM_THREADS

* Speed up with multithreading

* Update to use ArrayToString

* Fix tests

* Fix test

* Fix bug introduced in merge

* Minor updates

* Update docs
btrotta authored Dec 24, 2020
1 parent d90a16d commit fcfd413
Showing 41 changed files with 1,538 additions and 96 deletions.
4 changes: 2 additions & 2 deletions .ci/test.sh
@@ -52,8 +52,8 @@ if [[ $TRAVIS == "true" ]] && [[ $TASK == "lint" ]]; then
"r-lintr>=2.0"
pip install --user cpplint
echo "Linting Python code"
pycodestyle --ignore=E501,W503 --exclude=./compute,./.nuget,./external_libs . || exit -1
pydocstyle --convention=numpy --add-ignore=D105 --match-dir="^(?!^compute|external_libs|test|example).*" --match="(?!^test_|setup).*\.py" . || exit -1
pycodestyle --ignore=E501,W503 --exclude=./compute,./eigen,./.nuget,./external_libs . || exit -1
pydocstyle --convention=numpy --add-ignore=D105 --match-dir="^(?!^compute|^eigen|external_libs|test|example).*" --match="(?!^test_|setup).*\.py" . || exit -1
echo "Linting R code"
Rscript ${BUILD_DIRECTORY}/.ci/lint_r_code.R ${BUILD_DIRECTORY} || exit -1
echo "Linting C++ code"
3 changes: 3 additions & 0 deletions .gitmodules
@@ -1,6 +1,9 @@
[submodule "include/boost/compute"]
path = compute
url = https://github.com/boostorg/compute
[submodule "eigen"]
path = eigen
url = https://gitlab.com/libeigen/eigen.git
[submodule "external_libs/fmt"]
path = external_libs/fmt
url = https://github.com/fmtlib/fmt.git
3 changes: 3 additions & 0 deletions CMakeLists.txt
@@ -90,6 +90,9 @@ if(USE_SWIG)
endif()
endif(USE_SWIG)

SET(EIGEN_DIR "${PROJECT_SOURCE_DIR}/eigen")
include_directories(${EIGEN_DIR})

if(__BUILD_FOR_R)
list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake/modules")
find_package(LibR REQUIRED)
1 change: 1 addition & 0 deletions R-package/src/Makevars.in
@@ -48,6 +48,7 @@ OBJECTS = \
treelearner/data_parallel_tree_learner.o \
treelearner/feature_parallel_tree_learner.o \
treelearner/gpu_tree_learner.o \
treelearner/linear_tree_learner.o \
treelearner/serial_tree_learner.o \
treelearner/tree_learner.o \
treelearner/voting_parallel_tree_learner.o \
1 change: 1 addition & 0 deletions R-package/src/Makevars.win.in
@@ -49,6 +49,7 @@ OBJECTS = \
treelearner/data_parallel_tree_learner.o \
treelearner/feature_parallel_tree_learner.o \
treelearner/gpu_tree_learner.o \
treelearner/linear_tree_learner.o \
treelearner/serial_tree_learner.o \
treelearner/tree_learner.o \
treelearner/voting_parallel_tree_learner.o \
24 changes: 24 additions & 0 deletions docs/Parameters.rst
@@ -117,6 +117,26 @@ Core Parameters

- **Note**: internally, LightGBM uses ``gbdt`` mode for the first ``1 / learning_rate`` iterations

- ``linear_tree`` :raw-html:`<a id="linear_tree" title="Permalink to this parameter" href="#linear_tree">&#x1F517;&#xFE0E;</a>`, default = ``false``, type = bool

- fit piecewise linear gradient boosting tree, only works with cpu and serial tree learner

- tree splits are chosen in the usual way, but the model at each leaf is linear instead of constant

- the linear model at each leaf includes all the numerical features in that leaf's branch

- categorical features are used for splits as normal but are not used in the linear models

- missing values must be encoded as ``np.nan`` (Python) or ``NA`` (cli), not ``0``

- it is recommended to rescale data before training so that features have similar mean and standard deviation

- not yet supported in R-package

- ``regression_l1`` objective is not supported with linear tree boosting

- setting ``linear_tree = True`` significantly increases the memory use of LightGBM

- ``data`` :raw-html:`<a id="data" title="Permalink to this parameter" href="#data">&#x1F517;&#xFE0E;</a>`, default = ``""``, type = string, aliases: ``train``, ``train_data``, ``train_data_file``, ``data_filename``

- path of training data, LightGBM will train from this data
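The `linear_tree` description above says each leaf holds a linear model over the numerical features on its branch instead of a constant. A minimal sketch of what prediction at one such leaf could look like follows. The `leaf` dictionary layout and the `predict_leaf` helper are hypothetical names for illustration, and the fallback to the constant leaf value when a required feature is NaN is an assumption, not something this diff excerpt shows.

```python
import math


def predict_leaf(leaf, row):
    """Predict for one row that has already been routed to `leaf`.

    `leaf` is a hypothetical dict (not LightGBM's data structure) with:
      - "value": the ordinary constant leaf output
      - "const": intercept of the leaf's linear model
      - "features": indices of the numerical features in the linear model
      - "coeffs": one coefficient per entry of "features"
    """
    total = leaf["const"]
    for idx, coeff in zip(leaf["features"], leaf["coeffs"]):
        x = row[idx]
        if math.isnan(x):
            # Assumption: use the constant output if a needed value is missing.
            return leaf["value"]
        total += coeff * x
    return total


leaf = {"value": 0.5, "const": 0.1, "features": [0, 2], "coeffs": [2.0, -1.0]}
complete_row = predict_leaf(leaf, [1.0, 9.9, 3.0])          # uses the linear model
missing_row = predict_leaf(leaf, [float("nan"), 9.9, 3.0])  # uses the constant
print(complete_row, missing_row)
```

Feature 1 is ignored in both calls because it is not in the leaf's feature list, which mirrors the statement that only features on the leaf's branch enter its linear model.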
@@ -384,6 +404,10 @@ Learning Control Parameters

- L2 regularization

- ``linear_lambda`` :raw-html:`<a id="linear_lambda" title="Permalink to this parameter" href="#linear_lambda">&#x1F517;&#xFE0E;</a>`, default = ``0.0``, type = double, constraints: ``linear_lambda >= 0.0``

- Linear tree regularisation, the parameter `lambda` in Eq 3 of <https://arxiv.org/pdf/1802.05640.pdf>

- ``min_gain_to_split`` :raw-html:`<a id="min_gain_to_split" title="Permalink to this parameter" href="#min_gain_to_split">&#x1F517;&#xFE0E;</a>`, default = ``0.0``, type = double, aliases: ``min_split_gain``, constraints: ``min_gain_to_split >= 0.0``

- the minimal gain to perform split
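To illustrate what a regularisation parameter such as `linear_lambda` does, here is a sketch of a ridge-style leaf fit for a single feature, written in the gradient/hessian form common in gradient boosting: solve `(X^T H X + lambda * I) w = -X^T g`. The function name `fit_leaf_1d` is hypothetical. Whether this matches the exact form of Eq 3 in the referenced paper or the C++ implementation is an assumption, as is the choice to penalise only the slope and not the intercept.

```python
def fit_leaf_1d(xs, grads, hess, linear_lambda=0.0):
    """Fit slope w and intercept c for one leaf with one numerical feature.

    Solves (X^T H X + lambda * I') [w, c]^T = -X^T g, where X has columns
    [x, 1] and I' penalises only the slope (an assumption of this sketch).
    """
    a = b = d = 0.0   # entries of X^T H X = [[a, b], [b, d]]
    gx = g1 = 0.0     # entries of X^T g
    for x, g, h in zip(xs, grads, hess):
        a += h * x * x
        b += h * x
        d += h
        gx += g * x
        g1 += g
    a += linear_lambda
    det = a * d - b * b
    if det == 0.0:
        raise ValueError("singular system: not enough distinct data in leaf")
    # Closed-form inverse of the symmetric 2x2 matrix applied to -X^T g.
    w = (-d * gx + b * g1) / det
    c = (b * gx - a * g1) / det
    return w, c


# Squared-error loss with current score 0: gradient = score - y, hessian = 1.
xs = [0.0, 1.0, 2.0, 3.0]
ys = [2.0 * x + 1.0 for x in xs]
grads = [-y for y in ys]
hess = [1.0] * len(xs)

w_plain, c_plain = fit_leaf_1d(xs, grads, hess)
w_ridge, c_ridge = fit_leaf_1d(xs, grads, hess, linear_lambda=10.0)
print(w_plain, c_plain, w_ridge, c_ridge)
```

With `linear_lambda = 0` the fit recovers the exact line through the data; a positive value shrinks the slope toward zero, which is the usual motivation for this kind of penalty when a leaf holds few rows.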
1 change: 1 addition & 0 deletions docs/conf.py
@@ -202,6 +202,7 @@ def generate_doxygen_xml(app):
"SKIP_FUNCTION_MACROS=NO",
"SORT_BRIEF_DOCS=YES",
"WARN_AS_ERROR=YES",
"EXCLUDE_PATTERNS=*/eigen/*"
]
doxygen_input = '\n'.join(doxygen_args)
doxygen_input = bytes(doxygen_input, "utf-8")
1 change: 1 addition & 0 deletions eigen
Submodule eigen added at 8ba1b0
113 changes: 113 additions & 0 deletions examples/binary_classification/train_linear.conf
@@ -0,0 +1,113 @@
# task type, support train and predict
task = train

# boosting type, support gbdt for now, alias: boosting, boost
boosting_type = gbdt

# application type, support following application
# regression , regression task
# binary , binary classification task
# lambdarank , lambdarank task
# alias: application, app
objective = binary

linear_tree = true

# eval metrics, support multi metric, delimited by ',' , support following metrics
# l1
# l2 , default metric for regression
# ndcg , default metric for lambdarank
# auc
# binary_logloss , default metric for binary
# binary_error
metric = binary_logloss,auc

# frequency for metric output
metric_freq = 1

# true if need output metric for training data, alias: training_metric, train_metric
is_training_metric = true

# number of bins for feature bucket, 255 is a recommended setting, it can save memory and also has good accuracy.
max_bin = 255

# training data
# if a weight file exists, it should be named "binary.train.weight"
# alias: train_data, train
data = binary.train

# validation data, support multi validation data, separated by ','
# if a weight file exists, it should be named "binary.test.weight"
# alias: valid, test, test_data,
valid_data = binary.test

# number of trees(iterations), alias: num_tree, num_iteration, num_iterations, num_round, num_rounds
num_trees = 100

# shrinkage rate , alias: shrinkage_rate
learning_rate = 0.1

# number of leaves for one tree, alias: num_leaf
num_leaves = 63

# type of tree learner, support following types:
# serial , single machine version
# feature , use feature parallel to train
# data , use data parallel to train
# voting , use voting based parallel to train
# alias: tree
tree_learner = serial

# number of threads for multi-threading. One thread will use one CPU, default is set to #cpu.
# num_threads = 8

# feature sub-sample, will randomly select 80% of features to train on each iteration
# alias: sub_feature
feature_fraction = 0.8

# Support bagging (data sub-sample), will perform bagging every 5 iterations
bagging_freq = 5

# Bagging fraction, will randomly select 80% of data on bagging
# alias: sub_row
bagging_fraction = 0.8

# minimal number data for one leaf, use this to deal with over-fit
# alias : min_data_per_leaf, min_data
min_data_in_leaf = 50

# minimal sum hessians for one leaf, use this to deal with over-fit
min_sum_hessian_in_leaf = 5.0

# save memory and faster speed for sparse feature, alias: is_sparse
is_enable_sparse = true

# when data is bigger than memory size, set this to true. otherwise set false will have faster speed
# alias: two_round_loading, two_round
use_two_round_loading = false

# true if need to save data to binary file and application will auto load data from binary file next time
# alias: is_save_binary, save_binary
is_save_binary_file = false

# output model file
output_model = LightGBM_model.txt

# support continuous train from trained gbdt model
# input_model= trained_model.txt

# output prediction file for predict task
# output_result= prediction.txt


# number of machines in parallel training, alias: num_machine
num_machines = 1

# local listening port in parallel training, alias: local_port
local_listen_port = 12400

# machines list file for parallel training, alias: mlist
machine_list_file = mlist.txt

# force splits
# forced_splits = forced_splits.json
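The config file above uses a plain `key = value` layout with `#` comments. As a rough illustration of how such a file maps to a parameter dictionary, here is a small sketch of a reader. `parse_conf` is a hypothetical helper written for this note, not LightGBM's own parser, and it keeps every value as a string.

```python
def parse_conf(text):
    """Parse 'key = value' lines, ignoring blank lines and '#' comments."""
    params = {}
    for raw_line in text.splitlines():
        # Drop everything after the first '#', then surrounding whitespace.
        line = raw_line.split("#", 1)[0].strip()
        if not line or "=" not in line:
            continue
        key, value = line.split("=", 1)
        params[key.strip()] = value.strip()
    return params


sample = """
# task type, support train and predict
task = train
objective = binary
linear_tree = true
metric = binary_logloss,auc
# num_threads = 8
tree_learner = serial
"""

params = parse_conf(sample)
print(params)
```

The commented-out `num_threads` line is skipped, so only active settings such as `linear_tree = true` and `tree_learner = serial` end up in the dictionary.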
2 changes: 2 additions & 0 deletions include/LightGBM/boosting.h
@@ -312,6 +312,8 @@ class LIGHTGBM_EXPORT Boosting {
* \return The boosting object
*/
static Boosting* CreateBoosting(const std::string& type, const char* filename);

virtual bool IsLinear() const { return false; }
};

class GBDTBase : public Boosting {
8 changes: 8 additions & 0 deletions include/LightGBM/c_api.h
@@ -389,6 +389,14 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetGetNumData(DatasetHandle handle,
LIGHTGBM_C_EXPORT int LGBM_DatasetGetNumFeature(DatasetHandle handle,
int* out);

/*!
* \brief Get boolean representing whether booster is fitting linear trees.
* \param handle Handle of booster
* \param[out] out The address to hold linear indicator
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterGetLinear(BoosterHandle handle, bool* out);

/*!
* \brief Add features from ``source`` to ``target``.
* \param target The handle of the dataset to add features to
15 changes: 15 additions & 0 deletions include/LightGBM/config.h
@@ -148,6 +148,17 @@ struct Config {
// descl2 = **Note**: internally, LightGBM uses ``gbdt`` mode for the first ``1 / learning_rate`` iterations
std::string boosting = "gbdt";

// desc = fit piecewise linear gradient boosting tree, only works with cpu and serial tree learner
// descl2 = tree splits are chosen in the usual way, but the model at each leaf is linear instead of constant
// descl2 = the linear model at each leaf includes all the numerical features in that leaf's branch
// descl2 = categorical features are used for splits as normal but are not used in the linear models
// descl2 = missing values must be encoded as ``np.nan`` (Python) or ``NA`` (cli), not ``0``
// descl2 = it is recommended to rescale data before training so that features have similar mean and standard deviation
// descl2 = not yet supported in R-package
// descl2 = ``regression_l1`` objective is not supported with linear tree boosting
// descl2 = setting ``linear_tree = True`` significantly increases the memory use of LightGBM
bool linear_tree = false;

// alias = train, train_data, train_data_file, data_filename
// desc = path of training data, LightGBM will train from this data
// desc = **Note**: can be used only in CLI version
@@ -366,6 +377,10 @@ struct Config {
// desc = L2 regularization
double lambda_l2 = 0.0;

// check = >=0.0
// desc = Linear tree regularisation, the parameter `lambda` in Eq 3 of <https://arxiv.org/pdf/1802.05640.pdf>
double linear_lambda = 0.0;

// alias = min_split_gain
// check = >=0.0
// desc = the minimal gain to perform split
53 changes: 52 additions & 1 deletion include/LightGBM/dataset.h
@@ -337,6 +337,12 @@ class Dataset {
const int group = feature2group_[feature_idx];
const int sub_feature = feature2subfeature_[feature_idx];
feature_groups_[group]->PushData(tid, sub_feature, row_idx, feature_values[i]);
if (has_raw_) {
int feat_ind = numeric_feature_map_[feature_idx];
if (feat_ind >= 0) {
raw_data_[feat_ind][row_idx] = feature_values[i];
}
}
}
}
}
@@ -352,13 +358,25 @@ class Dataset {
const int group = feature2group_[feature_idx];
const int sub_feature = feature2subfeature_[feature_idx];
feature_groups_[group]->PushData(tid, sub_feature, row_idx, inner_data.second);
if (has_raw_) {
int feat_ind = numeric_feature_map_[feature_idx];
if (feat_ind >= 0) {
raw_data_[feat_ind][row_idx] = inner_data.second;
}
}
}
}
FinishOneRow(tid, row_idx, is_feature_added);
}

inline void PushOneData(int tid, data_size_t row_idx, int group, int sub_feature, double value) {
inline void PushOneData(int tid, data_size_t row_idx, int group, int feature_idx, int sub_feature, double value) {
feature_groups_[group]->PushData(tid, sub_feature, row_idx, value);
if (has_raw_) {
int feat_ind = numeric_feature_map_[feature_idx];
if (feat_ind >= 0) {
raw_data_[feat_ind][row_idx] = value;
}
}
}

inline int RealFeatureIndex(int fidx) const {
@@ -569,6 +587,9 @@ class Dataset {
/*! \brief Get Number of used features */
inline int num_features() const { return num_features_; }

/*! \brief Get number of numeric features */
inline int num_numeric_features() const { return num_numeric_features_; }

/*! \brief Get Number of feature groups */
inline int num_feature_groups() const { return num_groups_;}

@@ -632,6 +653,31 @@ class Dataset {

void AddFeaturesFrom(Dataset* other);

/*! \brief Get has_raw_ */
inline bool has_raw() const { return has_raw_; }

/*! \brief Set has_raw_ */
inline void SetHasRaw(bool has_raw) { has_raw_ = has_raw; }

/*! \brief Resize raw_data_ */
inline void ResizeRaw(int num_rows) {
if (static_cast<int>(raw_data_.size()) > num_numeric_features_) {
raw_data_.resize(num_numeric_features_);
}
for (size_t i = 0; i < raw_data_.size(); ++i) {
raw_data_[i].resize(num_rows);
}
int curr_size = static_cast<int>(raw_data_.size());
for (int i = curr_size; i < num_numeric_features_; ++i) {
raw_data_.push_back(std::vector<float>(num_rows, 0));
}
}

/*! \brief Get pointer to raw_data_ feature */
inline const float* raw_index(int feat_ind) const {
return raw_data_[numeric_feature_map_[feat_ind]].data();
}

private:
std::string data_filename_;
/*! \brief Store used features */
@@ -668,6 +714,11 @@ class Dataset {
bool use_missing_;
bool zero_as_missing_;
std::vector<int> feature_need_push_zeros_;
std::vector<std::vector<float>> raw_data_;
bool has_raw_;
/*! map feature (inner index) to its index in the list of numeric (non-categorical) features */
std::vector<int> numeric_feature_map_;
int num_numeric_features_;
};

} // namespace LightGBM
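The `dataset.h` changes above store raw values column-wise, one column per numerical feature, and use `numeric_feature_map_` to translate an inner feature index into a column index. The following sketch mirrors that idea in Python. The class `RawStore` and the helper `build_numeric_feature_map` are hypothetical, and mapping categorical features to `-1` is an assumption inferred from the `feat_ind >= 0` check, since the code that fills the map is not part of this excerpt.

```python
def build_numeric_feature_map(is_categorical):
    """Map each feature index to its position among numerical features.

    Categorical features map to -1 (assumed convention for this sketch).
    Returns the map and the number of numerical features.
    """
    mapping = []
    next_idx = 0
    for categorical in is_categorical:
        if categorical:
            mapping.append(-1)
        else:
            mapping.append(next_idx)
            next_idx += 1
    return mapping, next_idx


class RawStore:
    """Column-wise storage of raw values for numerical features only."""

    def __init__(self, is_categorical, num_rows):
        self.numeric_feature_map, self.num_numeric = build_numeric_feature_map(
            is_categorical
        )
        # One zero-initialised column per numerical feature.
        self.raw_data = [[0.0] * num_rows for _ in range(self.num_numeric)]

    def push(self, feature_idx, row_idx, value):
        """Store a value; silently skip categorical features."""
        feat_ind = self.numeric_feature_map[feature_idx]
        if feat_ind >= 0:
            self.raw_data[feat_ind][row_idx] = value

    def column(self, feature_idx):
        """Return the raw column for a numerical feature."""
        feat_ind = self.numeric_feature_map[feature_idx]
        if feat_ind < 0:
            raise KeyError("feature %d is categorical" % feature_idx)
        return self.raw_data[feat_ind]


store = RawStore(is_categorical=[False, True, False], num_rows=2)
store.push(0, 0, 1.5)
store.push(1, 0, 7.0)   # categorical, so nothing is stored
store.push(2, 1, -2.0)
print(store.numeric_feature_map, store.raw_data)
```

Keeping only numerical columns is consistent with the documented behaviour that categorical features are used for splits but not in the leaf linear models, and it limits the extra memory that raw storage requires.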
2 changes: 2 additions & 0 deletions include/LightGBM/dataset_loader.h
@@ -82,6 +82,8 @@ class DatasetLoader {
std::vector<std::string> feature_names_;
/*! \brief Mapper from real feature index to used index*/
std::unordered_set<int> categorical_features_;
/*! \brief Whether to store raw feature values */
bool store_raw_;
};

} // namespace LightGBM

2 comments on commit fcfd413

@gianpd gianpd commented on fcfd413 Jan 8, 2021

Why are linear_tree and linear_lambda not known parameters in version 3.1.1 of the Python API?

@jameslamb
Collaborator

> Why are linear_tree and linear_lambda not known parameters in version 3.1.1 of the Python API?

Correct. v3.1.1 was released on December 7 (https://github.com/microsoft/LightGBM/releases/tag/v3.1.1) and this pull request was merged on December 24. This feature will be in a future release of LightGBM.

If you have additional questions, please open an issue.
