diff --git a/.ci/test.sh b/.ci/test.sh
index 4cd181790ad7..776bf7aa0287 100755
--- a/.ci/test.sh
+++ b/.ci/test.sh
@@ -52,8 +52,8 @@ if [[ $TRAVIS == "true" ]] && [[ $TASK == "lint" ]]; then
"r-lintr>=2.0"
pip install --user cpplint
echo "Linting Python code"
- pycodestyle --ignore=E501,W503 --exclude=./compute,./.nuget,./external_libs . || exit -1
- pydocstyle --convention=numpy --add-ignore=D105 --match-dir="^(?!^compute|external_libs|test|example).*" --match="(?!^test_|setup).*\.py" . || exit -1
+ pycodestyle --ignore=E501,W503 --exclude=./compute,./eigen,./.nuget,./external_libs . || exit -1
+ pydocstyle --convention=numpy --add-ignore=D105 --match-dir="^(?!^compute|^eigen|external_libs|test|example).*" --match="(?!^test_|setup).*\.py" . || exit -1
echo "Linting R code"
Rscript ${BUILD_DIRECTORY}/.ci/lint_r_code.R ${BUILD_DIRECTORY} || exit -1
echo "Linting C++ code"
diff --git a/.gitmodules b/.gitmodules
index 8f1772cd19fa..ccab67d4a1f5 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -1,6 +1,9 @@
[submodule "include/boost/compute"]
path = compute
url = https://github.com/boostorg/compute
+[submodule "eigen"]
+ path = eigen
+ url = https://gitlab.com/libeigen/eigen.git
[submodule "external_libs/fmt"]
path = external_libs/fmt
url = https://github.com/fmtlib/fmt.git
diff --git a/CMakeLists.txt b/CMakeLists.txt
index c8559c966ab1..de0dd6573199 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -90,6 +90,9 @@ if(USE_SWIG)
endif()
endif(USE_SWIG)
+SET(EIGEN_DIR "${PROJECT_SOURCE_DIR}/eigen")
+include_directories(${EIGEN_DIR})
+
if(__BUILD_FOR_R)
list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake/modules")
find_package(LibR REQUIRED)
diff --git a/R-package/src/Makevars.in b/R-package/src/Makevars.in
index 0c580ebe36bf..2490ba0757df 100644
--- a/R-package/src/Makevars.in
+++ b/R-package/src/Makevars.in
@@ -48,6 +48,7 @@ OBJECTS = \
treelearner/data_parallel_tree_learner.o \
treelearner/feature_parallel_tree_learner.o \
treelearner/gpu_tree_learner.o \
+ treelearner/linear_tree_learner.o \
treelearner/serial_tree_learner.o \
treelearner/tree_learner.o \
treelearner/voting_parallel_tree_learner.o \
diff --git a/R-package/src/Makevars.win.in b/R-package/src/Makevars.win.in
index 952508e16a7d..0fb2de926905 100644
--- a/R-package/src/Makevars.win.in
+++ b/R-package/src/Makevars.win.in
@@ -49,6 +49,7 @@ OBJECTS = \
treelearner/data_parallel_tree_learner.o \
treelearner/feature_parallel_tree_learner.o \
treelearner/gpu_tree_learner.o \
+ treelearner/linear_tree_learner.o \
treelearner/serial_tree_learner.o \
treelearner/tree_learner.o \
treelearner/voting_parallel_tree_learner.o \
diff --git a/docs/Parameters.rst b/docs/Parameters.rst
index 22e98217d614..1287c1354331 100644
--- a/docs/Parameters.rst
+++ b/docs/Parameters.rst
@@ -117,6 +117,26 @@ Core Parameters
- **Note**: internally, LightGBM uses ``gbdt`` mode for the first ``1 / learning_rate`` iterations
+- ``linear_tree`` :raw-html:`🔗︎`, default = ``false``, type = bool
+
+ - fit piecewise linear gradient boosting trees; only works with the ``cpu`` device type and the ``serial`` tree learner
+
+ - tree splits are chosen in the usual way, but the model at each leaf is linear instead of constant
+
+ - the linear model at each leaf includes all the numerical features in that leaf's branch
+
+ - categorical features are used for splits as normal but are not used in the linear models
+
+ - missing values must be encoded as ``np.nan`` (Python) or ``NA`` (CLI), not ``0``
+
+ - it is recommended to rescale data before training so that features have similar mean and standard deviation
+
+ - not yet supported in R-package
+
+ - ``regression_l1`` objective is not supported with linear tree boosting
+
+ - setting ``linear_tree = True`` significantly increases the memory use of LightGBM
+
- ``data`` :raw-html:`🔗︎`, default = ``""``, type = string, aliases: ``train``, ``train_data``, ``train_data_file``, ``data_filename``
- path of training data, LightGBM will train from this data
@@ -384,6 +404,10 @@ Learning Control Parameters
- L2 regularization
+- ``linear_lambda`` :raw-html:`🔗︎`, default = ``0.0``, type = double, constraints: ``linear_lambda >= 0.0``
+
+ - linear tree regularization, corresponds to the parameter ``lambda`` in Eq. 3 of Gradient Boosting with Piece-Wise Linear Regression Trees
+
- ``min_gain_to_split`` :raw-html:`🔗︎`, default = ``0.0``, type = double, aliases: ``min_split_gain``, constraints: ``min_gain_to_split >= 0.0``
- the minimal gain to perform split
diff --git a/docs/conf.py b/docs/conf.py
index aa1cb2dc3610..e30566c6fc46 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -202,6 +202,7 @@ def generate_doxygen_xml(app):
"SKIP_FUNCTION_MACROS=NO",
"SORT_BRIEF_DOCS=YES",
"WARN_AS_ERROR=YES",
+ "EXCLUDE_PATTERNS=*/eigen/*"
]
doxygen_input = '\n'.join(doxygen_args)
doxygen_input = bytes(doxygen_input, "utf-8")
diff --git a/eigen b/eigen
new file mode 160000
index 000000000000..8ba1b0f41a79
--- /dev/null
+++ b/eigen
@@ -0,0 +1 @@
+Subproject commit 8ba1b0f41a7950dc3e1d4ed75859e36c73311235
diff --git a/examples/binary_classification/train_linear.conf b/examples/binary_classification/train_linear.conf
new file mode 100644
index 000000000000..616d5fc39e35
--- /dev/null
+++ b/examples/binary_classification/train_linear.conf
@@ -0,0 +1,113 @@
+# task type, support train and predict
+task = train
+
+# boosting type, support gbdt for now, alias: boosting, boost
+boosting_type = gbdt
+
+# application type, support following application
+# regression , regression task
+# binary , binary classification task
+# lambdarank , lambdarank task
+# alias: application, app
+objective = binary
+
+linear_tree = true
+
+# eval metrics, support multiple metrics, delimited by ',', support following metrics
+# l1
+# l2 , default metric for regression
+# ndcg , default metric for lambdarank
+# auc
+# binary_logloss , default metric for binary
+# binary_error
+metric = binary_logloss,auc
+
+# frequency for metric output
+metric_freq = 1
+
+# true if need to output metric for training data, alias: training_metric, train_metric
+is_training_metric = true
+
+# number of bins for feature bucket, 255 is a recommended setting, it can save memory and also has good accuracy.
+max_bin = 255
+
+# training data
+# if there is a weight file, it should be named "binary.train.weight"
+# alias: train_data, train
+data = binary.train
+
+# validation data, support multi validation data, separated by ','
+# if there is a weight file, it should be named "binary.test.weight"
+# alias: valid, test, test_data,
+valid_data = binary.test
+
+# number of trees(iterations), alias: num_tree, num_iteration, num_iterations, num_round, num_rounds
+num_trees = 100
+
+# shrinkage rate , alias: shrinkage_rate
+learning_rate = 0.1
+
+# number of leaves for one tree, alias: num_leaf
+num_leaves = 63
+
+# type of tree learner, support following types:
+# serial , single machine version
+# feature , use feature parallel to train
+# data , use data parallel to train
+# voting , use voting based parallel to train
+# alias: tree
+tree_learner = serial
+
+# number of threads for multi-threading. One thread will use one CPU, default is set to #cpu.
+# num_threads = 8
+
+# feature sub-sample, will randomly select 80% of features to train on each iteration
+# alias: sub_feature
+feature_fraction = 0.8
+
+# Support bagging (data sub-sample), will perform bagging every 5 iterations
+bagging_freq = 5
+
+# Bagging fraction, will randomly select 80% of data for bagging
+# alias: sub_row
+bagging_fraction = 0.8
+
+# minimal number of data in one leaf, use this to deal with over-fitting
+# alias : min_data_per_leaf, min_data
+min_data_in_leaf = 50
+
+# minimal sum of hessians in one leaf, use this to deal with over-fitting
+min_sum_hessian_in_leaf = 5.0
+
+# save memory and faster speed for sparse feature, alias: is_sparse
+is_enable_sparse = true
+
+# when data is bigger than memory size, set this to true. Otherwise, setting it to false will give faster speed
+# alias: two_round_loading, two_round
+use_two_round_loading = false
+
+# true if need to save data to binary file and application will auto load data from binary file next time
+# alias: is_save_binary, save_binary
+is_save_binary_file = false
+
+# output model file
+output_model = LightGBM_model.txt
+
+# support continuous train from trained gbdt model
+# input_model= trained_model.txt
+
+# output prediction file for predict task
+# output_result= prediction.txt
+
+
+# number of machines in parallel training, alias: num_machine
+num_machines = 1
+
+# local listening port in parallel training, alias: local_port
+local_listen_port = 12400
+
+# machines list file for parallel training, alias: mlist
+machine_list_file = mlist.txt
+
+# force splits
+# forced_splits = forced_splits.json
diff --git a/include/LightGBM/boosting.h b/include/LightGBM/boosting.h
index ca198d37671e..ddbcdbc18e44 100644
--- a/include/LightGBM/boosting.h
+++ b/include/LightGBM/boosting.h
@@ -312,6 +312,8 @@ class LIGHTGBM_EXPORT Boosting {
* \return The boosting object
*/
static Boosting* CreateBoosting(const std::string& type, const char* filename);
+
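+ /*! \brief Whether this boosting object fits linear models at the leaves (overridden by GBDT when linear_tree is enabled) */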
+ virtual bool IsLinear() const { return false; }
};
class GBDTBase : public Boosting {
diff --git a/include/LightGBM/c_api.h b/include/LightGBM/c_api.h
index 521ea821eccd..c16143d82160 100644
--- a/include/LightGBM/c_api.h
+++ b/include/LightGBM/c_api.h
@@ -389,6 +389,14 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetGetNumData(DatasetHandle handle,
LIGHTGBM_C_EXPORT int LGBM_DatasetGetNumFeature(DatasetHandle handle,
int* out);
+/*!
+* \brief Get boolean representing whether booster is fitting linear trees.
+* \param handle Handle of booster
+* \param[out] out The address to hold linear indicator
+* \return 0 when succeed, -1 when failure happens
+*/
+LIGHTGBM_C_EXPORT int LGBM_BoosterGetLinear(BoosterHandle handle, bool* out);
+
/*!
* \brief Add features from ``source`` to ``target``.
* \param target The handle of the dataset to add features to
diff --git a/include/LightGBM/config.h b/include/LightGBM/config.h
index 8c017ad26926..43f8edd99bf8 100644
--- a/include/LightGBM/config.h
+++ b/include/LightGBM/config.h
@@ -148,6 +148,17 @@ struct Config {
// descl2 = **Note**: internally, LightGBM uses ``gbdt`` mode for the first ``1 / learning_rate`` iterations
std::string boosting = "gbdt";
+ // desc = fit piecewise linear gradient boosting trees; only works with the ``cpu`` device type and the ``serial`` tree learner
+ // descl2 = tree splits are chosen in the usual way, but the model at each leaf is linear instead of constant
+ // descl2 = the linear model at each leaf includes all the numerical features in that leaf's branch
+ // descl2 = categorical features are used for splits as normal but are not used in the linear models
+ // descl2 = missing values must be encoded as ``np.nan`` (Python) or ``NA`` (CLI), not ``0``
+ // descl2 = it is recommended to rescale data before training so that features have similar mean and standard deviation
+ // descl2 = not yet supported in R-package
+ // descl2 = ``regression_l1`` objective is not supported with linear tree boosting
+ // descl2 = setting ``linear_tree = True`` significantly increases the memory use of LightGBM
+ bool linear_tree = false;
+
// alias = train, train_data, train_data_file, data_filename
// desc = path of training data, LightGBM will train from this data
// desc = **Note**: can be used only in CLI version
@@ -366,6 +377,10 @@ struct Config {
// desc = L2 regularization
double lambda_l2 = 0.0;
+ // check = >=0.0
+ // desc = linear tree regularization, corresponds to the parameter ``lambda`` in Eq. 3 of Gradient Boosting with Piece-Wise Linear Regression Trees
+ double linear_lambda = 0.0;
+
// alias = min_split_gain
// check = >=0.0
// desc = the minimal gain to perform split
diff --git a/include/LightGBM/dataset.h b/include/LightGBM/dataset.h
index 44c4196544b6..b1ee7a40dd40 100644
--- a/include/LightGBM/dataset.h
+++ b/include/LightGBM/dataset.h
@@ -337,6 +337,12 @@ class Dataset {
const int group = feature2group_[feature_idx];
const int sub_feature = feature2subfeature_[feature_idx];
feature_groups_[group]->PushData(tid, sub_feature, row_idx, feature_values[i]);
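+ // when raw data is stored (linear trees), also keep the untransformed value of numeric features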
+ if (has_raw_) {
+ int feat_ind = numeric_feature_map_[feature_idx];
+ if (feat_ind >= 0) {
+ raw_data_[feat_ind][row_idx] = feature_values[i];
+ }
+ }
}
}
}
@@ -352,13 +358,25 @@ class Dataset {
const int group = feature2group_[feature_idx];
const int sub_feature = feature2subfeature_[feature_idx];
feature_groups_[group]->PushData(tid, sub_feature, row_idx, inner_data.second);
+ if (has_raw_) {
+ int feat_ind = numeric_feature_map_[feature_idx];
+ if (feat_ind >= 0) {
+ raw_data_[feat_ind][row_idx] = inner_data.second;
+ }
+ }
}
}
FinishOneRow(tid, row_idx, is_feature_added);
}
- inline void PushOneData(int tid, data_size_t row_idx, int group, int sub_feature, double value) {
+ inline void PushOneData(int tid, data_size_t row_idx, int group, int feature_idx, int sub_feature, double value) {
feature_groups_[group]->PushData(tid, sub_feature, row_idx, value);
+ if (has_raw_) {
+ int feat_ind = numeric_feature_map_[feature_idx];
+ if (feat_ind >= 0) {
+ raw_data_[feat_ind][row_idx] = value;
+ }
+ }
}
inline int RealFeatureIndex(int fidx) const {
@@ -569,6 +587,9 @@ class Dataset {
/*! \brief Get Number of used features */
inline int num_features() const { return num_features_; }
+ /*! \brief Get number of numeric features */
+ inline int num_numeric_features() const { return num_numeric_features_; }
+
/*! \brief Get Number of feature groups */
inline int num_feature_groups() const { return num_groups_;}
@@ -632,6 +653,31 @@ class Dataset {
void AddFeaturesFrom(Dataset* other);
+ /*! \brief Get has_raw_ */
+ inline bool has_raw() const { return has_raw_; }
+
+ /*! \brief Set has_raw_ */
+ inline void SetHasRaw(bool has_raw) { has_raw_ = has_raw; }
+
+ /*! \brief Resize raw_data_ */
+ inline void ResizeRaw(int num_rows) {
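+ // ensure raw_data_ has exactly one column per numeric feature, each resized to num_rows (new columns are zero-filled)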
+ if (static_cast<int>(raw_data_.size()) > num_numeric_features_) {
+ raw_data_.resize(num_numeric_features_);
+ }
+ for (size_t i = 0; i < raw_data_.size(); ++i) {
+ raw_data_[i].resize(num_rows);
+ }
+ int curr_size = static_cast<int>(raw_data_.size());
+ for (int i = curr_size; i < num_numeric_features_; ++i) {
+ raw_data_.push_back(std::vector<float>(num_rows, 0));
+ }
+ }
+
+ /*! \brief Get pointer to raw_data_ feature */
+ inline const float* raw_index(int feat_ind) const {
+ return raw_data_[numeric_feature_map_[feat_ind]].data();
+ }
+
private:
std::string data_filename_;
/*! \brief Store used features */
@@ -668,6 +714,11 @@ class Dataset {
bool use_missing_;
bool zero_as_missing_;
std::vector feature_need_push_zeros_;
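+ /*! \brief Raw (unbinned) values of the numeric features, one vector per feature; only filled when has_raw_ is true */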
+ std::vector<std::vector<float>> raw_data_;
+ bool has_raw_;
+ /*! map feature (inner index) to its index in the list of numeric (non-categorical) features */
+ std::vector<int> numeric_feature_map_;
+ int num_numeric_features_;
};
} // namespace LightGBM
diff --git a/include/LightGBM/dataset_loader.h b/include/LightGBM/dataset_loader.h
index fde3d1c00e16..e72dd4910804 100644
--- a/include/LightGBM/dataset_loader.h
+++ b/include/LightGBM/dataset_loader.h
@@ -82,6 +82,8 @@ class DatasetLoader {
std::vector feature_names_;
/*! \brief Mapper from real feature index to used index*/
std::unordered_set categorical_features_;
+ /*! \brief Whether to store raw feature values */
+ bool store_raw_;
};
} // namespace LightGBM
diff --git a/include/LightGBM/tree.h b/include/LightGBM/tree.h
index fe46dc4f313e..d5d9b5282059 100644
--- a/include/LightGBM/tree.h
+++ b/include/LightGBM/tree.h
@@ -28,8 +28,9 @@ class Tree {
* \brief Constructor
* \param max_leaves The number of max leaves
* \param track_branch_features Whether to keep track of ancestors of leaf nodes
+ * \param is_linear Whether the tree has linear models at each leaf
*/
- explicit Tree(int max_leaves, bool track_branch_features);
+ explicit Tree(int max_leaves, bool track_branch_features, bool is_linear);
/*!
* \brief Constructor, from a string
@@ -148,9 +149,12 @@ class Tree {
/*! \brief Get parent of specific leaf*/
inline int leaf_parent(int leaf_idx) const {return leaf_parent_[leaf_idx]; }
- /*! \brief Get feature of specific split*/
+ /*! \brief Get feature of specific split (original feature index)*/
inline int split_feature(int split_idx) const { return split_feature_[split_idx]; }
+ /*! \brief Get feature of specific split*/
+ inline int split_feature_inner(int split_idx) const { return split_feature_inner_[split_idx]; }
+
/*! \brief Get features on leaf's branch*/
inline std::vector branch_features(int leaf) const { return branch_features_[leaf]; }
@@ -168,10 +172,6 @@ class Tree {
inline int right_child(int node_idx) const { return right_child_[node_idx]; }
- inline int split_feature_inner(int node_idx) const {
- return split_feature_inner_[node_idx];
- }
-
inline uint32_t threshold_in_bin(int node_idx) const {
return threshold_in_bin_[node_idx];
}
@@ -189,9 +189,21 @@ class Tree {
for (int i = 0; i < num_leaves_ - 1; ++i) {
leaf_value_[i] = MaybeRoundToZero(leaf_value_[i] * rate);
internal_value_[i] = MaybeRoundToZero(internal_value_[i] * rate);
+ if (is_linear_) {
+ leaf_const_[i] = MaybeRoundToZero(leaf_const_[i] * rate);
+ for (size_t j = 0; j < leaf_coeff_[i].size(); ++j) {
+ leaf_coeff_[i][j] = MaybeRoundToZero(leaf_coeff_[i][j] * rate);
+ }
+ }
}
leaf_value_[num_leaves_ - 1] =
MaybeRoundToZero(leaf_value_[num_leaves_ - 1] * rate);
+ if (is_linear_) {
+ leaf_const_[num_leaves_ - 1] = MaybeRoundToZero(leaf_const_[num_leaves_ - 1] * rate);
+ for (size_t j = 0; j < leaf_coeff_[num_leaves_ - 1].size(); ++j) {
+ leaf_coeff_[num_leaves_ - 1][j] = MaybeRoundToZero(leaf_coeff_[num_leaves_ - 1][j] * rate);
+ }
+ }
shrinkage_ *= rate;
}
@@ -205,6 +217,13 @@ class Tree {
}
leaf_value_[num_leaves_ - 1] =
MaybeRoundToZero(leaf_value_[num_leaves_ - 1] + val);
+ if (is_linear_) {
+#pragma omp parallel for schedule(static, 1024) if (num_leaves_ >= 2048)
+ for (int i = 0; i < num_leaves_ - 1; ++i) {
+ leaf_const_[i] = MaybeRoundToZero(leaf_const_[i] + val);
+ }
+ leaf_const_[num_leaves_ - 1] = MaybeRoundToZero(leaf_const_[num_leaves_ - 1] + val);
+ }
// force to 1.0
shrinkage_ = 1.0f;
}
@@ -213,6 +232,9 @@ class Tree {
num_leaves_ = 1;
shrinkage_ = 1.0f;
leaf_value_[0] = val;
+ if (is_linear_) {
+ leaf_const_[0] = val;
+ }
}
/*! \brief Serialize this object to string*/
@@ -257,6 +279,47 @@ class Tree {
int NextLeafId() const { return num_leaves_; }
+ /*! \brief Get the linear model constant term (bias) of one leaf */
+ inline double LeafConst(int leaf) const { return leaf_const_[leaf]; }
+
+ /*! \brief Get the linear model coefficients of one leaf */
+ inline std::vector<double> LeafCoeffs(int leaf) const { return leaf_coeff_[leaf]; }
+
+ /*! \brief Get the linear model features of one leaf */
+ inline std::vector<int> LeafFeaturesInner(int leaf) const { return leaf_features_inner_[leaf]; }
+
+ /*! \brief Get the linear model features of one leaf */
+ inline std::vector<int> LeafFeatures(int leaf) const { return leaf_features_[leaf]; }
+
+ /*! \brief Set the linear model coefficients on one leaf */
+ inline void SetLeafCoeffs(int leaf, const std::vector<double>& output) {
+ leaf_coeff_[leaf].resize(output.size());
+ for (size_t i = 0; i < output.size(); ++i) {
+ leaf_coeff_[leaf][i] = MaybeRoundToZero(output[i]);
+ }
+ }
+
+ /*! \brief Set the linear model constant term (bias) on one leaf */
+ inline void SetLeafConst(int leaf, double output) {
+ leaf_const_[leaf] = MaybeRoundToZero(output);
+ }
+
+ /*! \brief Set the linear model features on one leaf */
+ inline void SetLeafFeaturesInner(int leaf, const std::vector<int>& features) {
+ leaf_features_inner_[leaf] = features;
+ }
+
+ /*! \brief Set the linear model features on one leaf */
+ inline void SetLeafFeatures(int leaf, const std::vector<int>& features) {
+ leaf_features_[leaf] = features;
+ }
+
+ inline bool is_linear() const { return is_linear_; }
+
+ inline void SetIsLinear(bool is_linear) {
+ is_linear_ = is_linear;
+ }
+
private:
std::string NumericalDecisionIfElse(int node) const;
@@ -454,6 +517,16 @@ class Tree {
std::vector> branch_features_;
double shrinkage_;
int max_depth_;
+ /*! \brief Tree has linear model at each leaf */
+ bool is_linear_;
+ /*! \brief coefficients of linear models on leaves */
+ std::vector<std::vector<double>> leaf_coeff_;
+ /*! \brief constant term (bias) of linear models on leaves */
+ std::vector<double> leaf_const_;
+ /*! \brief features used in leaf linear models; indexing is relative to num_total_features_ */
+ std::vector<std::vector<int>> leaf_features_;
+ /*! \brief features used in leaf linear models; indexing is relative to used_features_ */
+ std::vector<std::vector<int>> leaf_features_inner_;
};
inline void Tree::Split(int leaf, int feature, int real_feature,
@@ -501,20 +574,65 @@ inline void Tree::Split(int leaf, int feature, int real_feature,
}
inline double Tree::Predict(const double* feature_values) const {
- if (num_leaves_ > 1) {
- int leaf = GetLeaf(feature_values);
- return LeafOutput(leaf);
+ if (is_linear_) {
+ int leaf = GetLeaf(feature_values);
+ double output = leaf_const_[leaf];
+ bool nan_found = false;
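+ // accumulate the leaf's linear model; if any feature value is NaN, fall back to the constant leaf output below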
+ for (size_t i = 0; i < leaf_features_[leaf].size(); ++i) {
+ int feat_raw = leaf_features_[leaf][i];
+ double feat_val = feature_values[feat_raw];
+ if (std::isnan(feat_val)) {
+ nan_found = true;
+ break;
+ } else {
+ output += leaf_coeff_[leaf][i] * feat_val;
+ }
+ }
+ if (nan_found) {
+ return LeafOutput(leaf);
+ } else {
+ return output;
+ }
} else {
- return leaf_value_[0];
+ if (num_leaves_ > 1) {
+ int leaf = GetLeaf(feature_values);
+ return LeafOutput(leaf);
+ } else {
+ return leaf_value_[0];
+ }
}
}
inline double Tree::PredictByMap(const std::unordered_map<int, double>& feature_values) const {
- if (num_leaves_ > 1) {
+ if (is_linear_) {
int leaf = GetLeafByMap(feature_values);
- return LeafOutput(leaf);
+ double output = leaf_const_[leaf];
+ bool nan_found = false;
+ for (size_t i = 0; i < leaf_features_[leaf].size(); ++i) {
+ int feat = leaf_features_[leaf][i];
+ auto val_it = feature_values.find(feat);
+ if (val_it != feature_values.end()) {
+ double feat_val = val_it->second;
+ if (std::isnan(feat_val)) {
+ nan_found = true;
+ break;
+ } else {
+ output += leaf_coeff_[leaf][i] * feat_val;
+ }
+ }
+ }
+ if (nan_found) {
+ return LeafOutput(leaf);
+ } else {
+ return output;
+ }
} else {
- return leaf_value_[0];
+ if (num_leaves_ > 1) {
+ int leaf = GetLeafByMap(feature_values);
+ return LeafOutput(leaf);
+ } else {
+ return leaf_value_[0];
+ }
}
}
diff --git a/include/LightGBM/tree_learner.h b/include/LightGBM/tree_learner.h
index e0fb34890579..55343ad714ee 100644
--- a/include/LightGBM/tree_learner.h
+++ b/include/LightGBM/tree_learner.h
@@ -36,6 +36,9 @@ class TreeLearner {
*/
virtual void Init(const Dataset* train_data, bool is_constant_hessian) = 0;
+ /*! Initialise some temporary storage, only needed for the linear tree; needs to be a method of TreeLearner since we call it in GBDT::RefitTree */
+ virtual void InitLinear(const Dataset* /*train_data*/, const int /*max_leaves*/) {}
+
virtual void ResetIsConstantHessian(bool is_constant_hessian) = 0;
virtual void ResetTrainingData(const Dataset* train_data,
@@ -53,9 +56,10 @@ class TreeLearner {
* \brief training tree model on dataset
* \param gradients The first order gradients
* \param hessians The second order gradients
+ * \param is_first_tree If linear tree learning is enabled, first tree needs to be handled differently
* \return A trained tree
*/
- virtual Tree* Train(const score_t* gradients, const score_t* hessians) = 0;
+ virtual Tree* Train(const score_t* gradients, const score_t* hessians, bool is_first_tree) = 0;
/*!
* \brief use an existing tree to fit the new gradients and hessians.
@@ -63,7 +67,7 @@ class TreeLearner {
virtual Tree* FitByExistingTree(const Tree* old_tree, const score_t* gradients, const score_t* hessians) const = 0;
virtual Tree* FitByExistingTree(const Tree* old_tree, const std::vector<int>& leaf_pred,
- const score_t* gradients, const score_t* hessians) = 0;
+ const score_t* gradients, const score_t* hessians) const = 0;
/*!
* \brief Set bagging data
@@ -94,11 +98,12 @@ class TreeLearner {
* \brief Create object of tree learner
* \param learner_type Type of tree learner
* \param device_type Type of tree learner
* \param config config of tree
*/
static TreeLearner* CreateTreeLearner(const std::string& learner_type,
- const std::string& device_type,
- const Config* config);
+ const std::string& device_type,
+ const Config* config);
};
} // namespace LightGBM
diff --git a/include/LightGBM/utils/openmp_wrapper.h b/include/LightGBM/utils/openmp_wrapper.h
index 71574ff894b6..1278d172a25d 100644
--- a/include/LightGBM/utils/openmp_wrapper.h
+++ b/include/LightGBM/utils/openmp_wrapper.h
@@ -78,6 +78,7 @@ class ThreadExceptionHelper {
All #pragma omp should be ignored by the compiler **/
inline void omp_set_num_threads(int) {}
inline int omp_get_num_threads() {return 1;}
+ inline int omp_get_max_threads() {return 1;}
inline int omp_get_thread_num() {return 0;}
inline int OMP_NUM_THREADS() { return 1; }
#ifdef __cplusplus
diff --git a/python-package/MANIFEST.in b/python-package/MANIFEST.in
index 62bc7f64fcfb..cc0ab0626f91 100644
--- a/python-package/MANIFEST.in
+++ b/python-package/MANIFEST.in
@@ -10,6 +10,7 @@ include compile/compute/CMakeLists.txt
recursive-include compile/compute/cmake *
recursive-include compile/compute/include *
recursive-include compile/compute/meta *
+recursive-include compile/eigen *
include compile/external_libs/fast_double_parser/CMakeLists.txt
include compile/external_libs/fast_double_parser/LICENSE
include compile/external_libs/fast_double_parser/LICENSE.BSL
diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py
index 4e50fe0f5b81..a48350d9e280 100644
--- a/python-package/lightgbm/basic.py
+++ b/python-package/lightgbm/basic.py
@@ -1011,6 +1011,7 @@ def get_params(self):
"ignore_column",
"is_enable_sparse",
"label_column",
+ "linear_tree",
"max_bin",
"max_bin_by_feature",
"min_data_in_bin",
@@ -2999,8 +3000,13 @@ def refit(self, data, label, decay_rate=0.9, **kwargs):
predictor = self._to_predictor(copy.deepcopy(kwargs))
leaf_preds = predictor.predict(data, -1, pred_leaf=True)
nrow, ncol = leaf_preds.shape
- train_set = Dataset(data, label, silent=True)
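+ # carry over the linear_tree setting of the existing booster so the refit Dataset also stores raw feature values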
+ out_is_linear = ctypes.c_bool(False)
+ _safe_call(_LIB.LGBM_BoosterGetLinear(
+ self.handle,
+ ctypes.byref(out_is_linear)))
new_params = copy.deepcopy(self.params)
+ new_params["linear_tree"] = out_is_linear.value
+ train_set = Dataset(data, label, silent=True, params=new_params)
new_params['refit_decay_rate'] = decay_rate
new_booster = Booster(new_params, train_set)
# Copy models
diff --git a/python-package/setup.py b/python-package/setup.py
index a2e8a0cf3560..32f4d12f92f5 100644
--- a/python-package/setup.py
+++ b/python-package/setup.py
@@ -65,6 +65,7 @@ def copy_files_helper(folder_name):
if not os.path.isfile(os.path.join(CURRENT_DIR, '_IS_SOURCE_PACKAGE.txt')):
copy_files_helper('include')
copy_files_helper('src')
+ copy_files_helper('eigen')
copy_files_helper('external_libs')
if not os.path.exists(os.path.join(CURRENT_DIR, "compile", "windows")):
os.makedirs(os.path.join(CURRENT_DIR, "compile", "windows"))
diff --git a/src/boosting/gbdt.cpp b/src/boosting/gbdt.cpp
index 84b6c0a9bd3a..cf49060f74e1 100644
--- a/src/boosting/gbdt.cpp
+++ b/src/boosting/gbdt.cpp
@@ -40,6 +40,7 @@ GBDT::GBDT()
bagging_runner_(0, bagging_rand_block_) {
average_output_ = false;
tree_learner_ = nullptr;
+ linear_tree_ = false;
}
GBDT::~GBDT() {
@@ -88,7 +89,8 @@ void GBDT::Init(const Config* config, const Dataset* train_data, const Objective
is_constant_hessian_ = GetIsConstHessian(objective_function);
- tree_learner_ = std::unique_ptr(TreeLearner::CreateTreeLearner(config_->tree_learner, config_->device_type, config_.get()));
+ tree_learner_ = std::unique_ptr<TreeLearner>(TreeLearner::CreateTreeLearner(config_->tree_learner, config_->device_type,
+ config_.get()));
// init tree learner
tree_learner_->Init(train_data_, is_constant_hessian_);
@@ -129,6 +131,10 @@ void GBDT::Init(const Config* config, const Dataset* train_data, const Objective
class_need_train_[i] = objective_function_->ClassNeedTrain(i);
}
}
+
+ if (config_->linear_tree) {
+ linear_tree_ = true;
+ }
}
void GBDT::AddValidDataset(const Dataset* valid_data,
@@ -282,6 +288,19 @@ void GBDT::RefitTree(const std::vector>& tree_leaf_prediction)
CHECK_EQ(static_cast(models_.size()), tree_leaf_prediction[0].size());
int num_iterations = static_cast(models_.size() / num_tree_per_iteration_);
std::vector leaf_pred(num_data_);
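+ // for linear trees, size the tree learner's per-leaf buffers to the largest leaf index seen in the supplied predictions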
+ if (linear_tree_) {
+ std::vector<int> max_leaves_by_thread = std::vector<int>(OMP_NUM_THREADS(), 0);
+ #pragma omp parallel for schedule(static)
+ for (int i = 0; i < static_cast<int>(tree_leaf_prediction.size()); ++i) {
+ int tid = omp_get_thread_num();
+ for (size_t j = 0; j < tree_leaf_prediction[i].size(); ++j) {
+ max_leaves_by_thread[tid] = std::max(max_leaves_by_thread[tid], tree_leaf_prediction[i][j]);
+ }
+ }
+ int max_leaves = *std::max_element(max_leaves_by_thread.begin(), max_leaves_by_thread.end());
+ max_leaves += 1;
+ tree_learner_->InitLinear(train_data_, max_leaves);
+ }
for (int iter = 0; iter < num_iterations; ++iter) {
Boosting();
for (int tree_id = 0; tree_id < num_tree_per_iteration_; ++tree_id) {
@@ -365,7 +384,7 @@ bool GBDT::TrainOneIter(const score_t* gradients, const score_t* hessians) {
bool should_continue = false;
for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
const size_t offset = static_cast(cur_tree_id) * num_data_;
- std::unique_ptr new_tree(new Tree(2, false));
+ std::unique_ptr<Tree> new_tree(new Tree(2, false, false));
if (class_need_train_[cur_tree_id] && train_data_->num_features() > 0) {
auto grad = gradients + offset;
auto hess = hessians + offset;
@@ -378,7 +397,8 @@ bool GBDT::TrainOneIter(const score_t* gradients, const score_t* hessians) {
grad = gradients_.data() + offset;
hess = hessians_.data() + offset;
}
- new_tree.reset(tree_learner_->Train(grad, hess));
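+ // the linear tree learner treats the first tree of each class differently, so tell it whether this is the first iteration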
+ bool is_first_tree = models_.size() < static_cast<size_t>(num_tree_per_iteration_);
+ new_tree.reset(tree_learner_->Train(grad, hess, is_first_tree));
}
if (new_tree->num_leaves() > 1) {
diff --git a/src/boosting/gbdt.h b/src/boosting/gbdt.h
index 0d38385d5f06..998bd5353ce2 100644
--- a/src/boosting/gbdt.h
+++ b/src/boosting/gbdt.h
@@ -44,6 +44,7 @@ class GBDT : public GBDTBase {
*/
~GBDT();
+
/*!
* \brief Initialization logic
* \param gbdt_config Config for boosting
@@ -391,6 +392,8 @@ class GBDT : public GBDTBase {
*/
const char* SubModelName() const override { return "tree"; }
+ bool IsLinear() const override { return linear_tree_; }
+
protected:
virtual bool GetIsConstHessian(const ObjectiveFunction* objective_function) {
if (objective_function != nullptr) {
@@ -530,6 +533,7 @@ class GBDT : public GBDTBase {
std::vector bagging_rands_;
ParallelPartitionRunner bagging_runner_;
Json forced_splits_json_;
+ bool linear_tree_;
};
} // namespace LightGBM
diff --git a/src/boosting/gbdt_model_text.cpp b/src/boosting/gbdt_model_text.cpp
index e5cec8b61db0..8b9a2b8b450b 100644
--- a/src/boosting/gbdt_model_text.cpp
+++ b/src/boosting/gbdt_model_text.cpp
@@ -581,6 +581,11 @@ bool GBDT::LoadModelFromString(const char* buffer, size_t len) {
break;
} else if (is_inparameter) {
ss << cur_line << "\n";
+ if (Common::StartsWith(cur_line, "[linear_tree: ")) {
+ int is_linear = 0;
+ Common::Atoi(cur_line.substr(14, 1).c_str(), &is_linear);
+ linear_tree_ = static_cast(is_linear);
+ }
}
}
p += line_len;
diff --git a/src/boosting/rf.hpp b/src/boosting/rf.hpp
index 0768fe10ca31..5a9eb226fef5 100644
--- a/src/boosting/rf.hpp
+++ b/src/boosting/rf.hpp
@@ -109,7 +109,7 @@ class RF : public GBDT {
gradients = gradients_.data();
hessians = hessians_.data();
for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
- std::unique_ptr new_tree(new Tree(2, false));
+ std::unique_ptr<Tree> new_tree(new Tree(2, false, false));
size_t offset = static_cast(cur_tree_id)* num_data_;
if (class_need_train_[cur_tree_id]) {
auto grad = gradients + offset;
@@ -125,7 +125,7 @@ class RF : public GBDT {
hess = tmp_hess_.data();
}
- new_tree.reset(tree_learner_->Train(grad, hess));
+ new_tree.reset(tree_learner_->Train(grad, hess, false));
}
if (new_tree->num_leaves() > 1) {
diff --git a/src/c_api.cpp b/src/c_api.cpp
index 5cd0eab3d77f..4cf9a5e5072a 100644
--- a/src/c_api.cpp
+++ b/src/c_api.cpp
@@ -289,6 +289,9 @@ class Booster {
"You need to set `feature_pre_filter=false` to dynamically change "
"the `min_data_in_leaf`.");
}
+ if (new_param.count("linear_tree") && (new_config.linear_tree != old_config.linear_tree)) {
+ Log::Fatal("Cannot change between gbdt_linear boosting and other boosting types after Dataset handle has been constructed.");
+ }
}
void ResetConfig(const char* parameters) {
@@ -960,6 +963,9 @@ int LGBM_DatasetPushRows(DatasetHandle dataset,
API_BEGIN();
auto p_dataset = reinterpret_cast(dataset);
auto get_row_fun = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, 1);
+ if (p_dataset->has_raw()) {
+ p_dataset->ResizeRaw(p_dataset->num_numeric_features() + nrow);
+ }
OMP_INIT_EX();
#pragma omp parallel for schedule(static)
for (int i = 0; i < nrow; ++i) {
@@ -990,14 +996,16 @@ int LGBM_DatasetPushRowsByCSR(DatasetHandle dataset,
auto p_dataset = reinterpret_cast(dataset);
auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
int32_t nrow = static_cast(nindptr - 1);
+ if (p_dataset->has_raw()) {
+ p_dataset->ResizeRaw(p_dataset->num_numeric_features() + nrow);
+ }
OMP_INIT_EX();
#pragma omp parallel for schedule(static)
for (int i = 0; i < nrow; ++i) {
OMP_LOOP_EX_BEGIN();
const int tid = omp_get_thread_num();
auto one_row = get_row_fun(i);
- p_dataset->PushOneRow(tid,
- static_cast(start_row + i), one_row);
+ p_dataset->PushOneRow(tid, static_cast(start_row + i), one_row);
OMP_LOOP_EX_END();
}
OMP_THROW_EX();
@@ -1090,6 +1098,9 @@ int LGBM_DatasetCreateFromMats(int32_t nmat,
ret.reset(new Dataset(total_nrow));
ret->CreateValid(
reinterpret_cast(reference));
+ if (ret->has_raw()) {
+ ret->ResizeRaw(total_nrow);
+ }
}
int32_t start_row = 0;
for (int j = 0; j < nmat; ++j) {
@@ -1166,6 +1177,9 @@ int LGBM_DatasetCreateFromCSR(const void* indptr,
ret.reset(new Dataset(nrow));
ret->CreateValid(
reinterpret_cast(reference));
+ if (ret->has_raw()) {
+ ret->ResizeRaw(nrow);
+ }
}
OMP_INIT_EX();
#pragma omp parallel for schedule(static)
@@ -1234,6 +1248,9 @@ int LGBM_DatasetCreateFromCSRFunc(void* get_row_funptr,
ret.reset(new Dataset(nrow));
ret->CreateValid(
reinterpret_cast(reference));
+ if (ret->has_raw()) {
+ ret->ResizeRaw(nrow);
+ }
}
OMP_INIT_EX();
@@ -1326,12 +1343,12 @@ int LGBM_DatasetCreateFromCSC(const void* col_ptr,
row_idx = pair.first;
// no more data
if (row_idx < 0) { break; }
- ret->PushOneData(tid, row_idx, group, sub_feature, pair.second);
+ ret->PushOneData(tid, row_idx, group, feature_idx, sub_feature, pair.second);
}
} else {
for (int row_idx = 0; row_idx < nrow; ++row_idx) {
auto val = col_it.Get(row_idx);
- ret->PushOneData(tid, row_idx, group, sub_feature, val);
+ ret->PushOneData(tid, row_idx, group, feature_idx, sub_feature, val);
}
}
OMP_LOOP_EX_END();
@@ -1600,6 +1617,13 @@ int LGBM_BoosterGetNumClasses(BoosterHandle handle, int* out_len) {
API_END();
}
+int LGBM_BoosterGetLinear(BoosterHandle handle, bool* out) {
+ API_BEGIN();
+ Booster* ref_booster = reinterpret_cast<Booster*>(handle);
+ *out = ref_booster->GetBoosting()->IsLinear();
+ API_END();
+}
+
int LGBM_BoosterRefit(BoosterHandle handle, const int32_t* leaf_preds, int32_t nrow, int32_t ncol) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast(handle);
diff --git a/src/io/config.cpp b/src/io/config.cpp
index 45987b6760c5..ce034505ee4b 100644
--- a/src/io/config.cpp
+++ b/src/io/config.cpp
@@ -335,13 +335,28 @@ void Config::CheckParamConflict() {
Log::Warning("Although \"deterministic\" is set, the results ran by GPU may be non-deterministic.");
}
}
-
// force gpu_use_dp for CUDA
if (device_type == std::string("cuda") && !gpu_use_dp) {
Log::Warning("CUDA currently requires double precision calculations.");
gpu_use_dp = true;
}
-
+ // linear tree learner must be serial type and cpu device
+ if (linear_tree) {
+ if (device_type == std::string("gpu")) {
+ device_type = "cpu";
+ Log::Warning("Linear tree learner only works with CPU.");
+ }
+ if (tree_learner != std::string("serial")) {
+ tree_learner = "serial";
+ Log::Warning("Linear tree learner must be serial.");
+ }
+ if (zero_as_missing) {
+ Log::Fatal("zero_as_missing must be false when fitting linear trees.");
+ }
+ if (objective == std::string("regression_l1")) {
+ Log::Fatal("Cannot use regression_l1 objective when fitting linear trees.");
+ }
+ }
// min_data_in_leaf must be at least 2 if path smoothing is active. This is because when the split is calculated
// the count is calculated using the proportion of hessian in the leaf which is rounded up to nearest int, so it can
// be 1 when there is actually no data in the leaf. In rare cases this can cause a bug because with path smoothing the
diff --git a/src/io/config_auto.cpp b/src/io/config_auto.cpp
index b4a55ce589ef..06c53e84268a 100644
--- a/src/io/config_auto.cpp
+++ b/src/io/config_auto.cpp
@@ -174,6 +174,7 @@ const std::unordered_set& Config::parameter_set() {
"task",
"objective",
"boosting",
+ "linear_tree",
"data",
"valid",
"num_iterations",
@@ -205,6 +206,7 @@ const std::unordered_set& Config::parameter_set() {
"max_delta_step",
"lambda_l1",
"lambda_l2",
+ "linear_lambda",
"min_gain_to_split",
"drop_rate",
"max_drop",
@@ -304,6 +306,8 @@ const std::unordered_set& Config::parameter_set() {
void Config::GetMembersFromString(const std::unordered_map& params) {
std::string tmp_str = "";
+ GetBool(params, "linear_tree", &linear_tree);
+
GetString(params, "data", &data);
if (GetString(params, "valid", &tmp_str)) {
@@ -380,6 +384,9 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::string>& params) {
diff --git a/src/io/dataset.cpp b/src/io/dataset.cpp
--- a/src/io/dataset.cpp
+++ b/src/io/dataset.cpp
@@ ... @@ void Dataset::Construct(std::vector<std::unique_ptr<BinMapper>>* bin_mappers,
bin_construct_sample_cnt_ = io_config.bin_construct_sample_cnt;
use_missing_ = io_config.use_missing;
zero_as_missing_ = io_config.zero_as_missing;
+ has_raw_ = false;
+ if (io_config.linear_tree) {
+ has_raw_ = true;
+ }
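+ // map each used feature to its position among the numeric features; categorical features map to -1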
+ numeric_feature_map_ = std::vector<int>(num_features_, -1);
+ num_numeric_features_ = 0;
+ for (int i = 0; i < num_features_; ++i) {
+ if (FeatureBinMapper(i)->bin_type() == BinType::NumericalBin) {
+ numeric_feature_map_[i] = num_numeric_features_;
+ ++num_numeric_features_;
+ }
+ }
}
void Dataset::FinishLoad() {
@@ -696,6 +710,7 @@ void Dataset::CopyFeatureMapperFrom(const Dataset* dataset) {
feature_groups_.clear();
num_features_ = dataset->num_features_;
num_groups_ = dataset->num_groups_;
+ has_raw_ = dataset->has_raw();
// copy feature bin mapper data
for (int i = 0; i < num_groups_; ++i) {
feature_groups_.emplace_back(
@@ -722,7 +737,10 @@ void Dataset::CreateValid(const Dataset* dataset) {
num_groups_ = num_features_;
feature2group_.clear();
feature2subfeature_.clear();
-
+ has_raw_ = dataset->has_raw();
+ numeric_feature_map_ = dataset->numeric_feature_map_;
+ num_numeric_features_ = dataset->num_numeric_features_;
+ // copy feature bin mapper data
feature_need_push_zeros_.clear();
group_bin_boundaries_.clear();
uint64_t num_total_bin = 0;
@@ -785,6 +803,17 @@ void Dataset::CopySubrow(const Dataset* fullset,
metadata_.Init(fullset->metadata_, used_indices, num_used_indices);
}
is_finish_load_ = true;
+ numeric_feature_map_ = fullset->numeric_feature_map_;
+ num_numeric_features_ = fullset->num_numeric_features_;
+ if (has_raw_) {
+ ResizeRaw(num_used_indices);
+#pragma omp parallel for schedule(static)
+ for (int i = 0; i < num_used_indices; ++i) {
+ for (int j = 0; j < num_numeric_features_; ++j) {
+ raw_data_[j][i] = fullset->raw_data_[j][used_indices[i]];
+ }
+ }
+ }
}
bool Dataset::SetFloatField(const char* field_name, const float* field_data,
@@ -922,8 +951,7 @@ void Dataset::SaveBinaryFile(const char* bin_filename) {
2 * VirtualFileWriter::AlignedSize(sizeof(int) * num_groups_) +
VirtualFileWriter::AlignedSize(sizeof(int32_t) * num_total_features_) +
VirtualFileWriter::AlignedSize(sizeof(int)) * 3 +
- VirtualFileWriter::AlignedSize(sizeof(bool)) * 2;
-
+ VirtualFileWriter::AlignedSize(sizeof(bool)) * 3;
// size of feature names
for (int i = 0; i < num_total_features_; ++i) {
size_of_header +=
@@ -947,6 +975,7 @@ void Dataset::SaveBinaryFile(const char* bin_filename) {
writer->AlignedWrite(&min_data_in_bin_, sizeof(min_data_in_bin_));
writer->AlignedWrite(&use_missing_, sizeof(use_missing_));
writer->AlignedWrite(&zero_as_missing_, sizeof(zero_as_missing_));
+ writer->AlignedWrite(&has_raw_, sizeof(has_raw_));
writer->AlignedWrite(used_feature_map_.data(),
sizeof(int) * num_total_features_);
writer->AlignedWrite(&num_groups_, sizeof(num_groups_));
@@ -998,6 +1027,18 @@ void Dataset::SaveBinaryFile(const char* bin_filename) {
// write feature
feature_groups_[i]->SaveBinaryToFile(writer.get());
}
+
+ // write raw data; use row-major order so we can read row-by-row
+ if (has_raw_) {
+ for (int i = 0; i < num_data_; ++i) {
+ for (int j = 0; j < num_features_; ++j) {
+ int feat_ind = numeric_feature_map_[j];
+ if (feat_ind > -1) {
+ writer->Write(&raw_data_[feat_ind][i], sizeof(float));
+ }
+ }
+ }
+ }
}
}
@@ -1264,6 +1305,9 @@ void Dataset::AddFeaturesFrom(Dataset* other) {
"Cannot add features from other Dataset with a different number of "
"rows");
}
+ if (other->has_raw_ != has_raw_) {
+ Log::Fatal("Can only add features from other Dataset if both or neither have raw data.");
+ }
int mv_gid = -1;
int other_mv_gid = -1;
for (int i = 0; i < num_groups_; ++i) {
@@ -1393,6 +1437,20 @@ void Dataset::AddFeaturesFrom(Dataset* other) {
PushClearIfEmpty(&max_bin_by_feature_, num_total_features_,
other->max_bin_by_feature_, other->num_total_features_, -1);
num_total_features_ += other->num_total_features_;
+ for (size_t i = 0; i < (other->numeric_feature_map_).size(); ++i) {
+ int feat_ind = other->numeric_feature_map_[i];
+ if (feat_ind > -1) {
+ numeric_feature_map_.push_back(feat_ind + num_numeric_features_);
+ } else {
+ numeric_feature_map_.push_back(-1);
+ }
+ }
+ num_numeric_features_ += other->num_numeric_features_;
+ if (has_raw_) {
+ for (int i = 0; i < other->num_numeric_features_; ++i) {
+ raw_data_.push_back(other->raw_data_[i]);
+ }
+ }
}
} // namespace LightGBM
diff --git a/src/io/dataset_loader.cpp b/src/io/dataset_loader.cpp
index ba707329f7d2..24c9637bfa56 100644
--- a/src/io/dataset_loader.cpp
+++ b/src/io/dataset_loader.cpp
@@ -22,6 +22,10 @@ DatasetLoader::DatasetLoader(const Config& io_config, const PredictFunction& pre
weight_idx_ = NO_SPECIFIC;
group_idx_ = NO_SPECIFIC;
SetHeader(filename);
+ store_raw_ = false;
+ if (io_config.linear_tree) {
+ store_raw_ = true;
+ }
}
DatasetLoader::~DatasetLoader() {
@@ -183,6 +187,9 @@ Dataset* DatasetLoader::LoadFromFile(const char* filename, int rank, int num_mac
}
}
auto dataset = std::unique_ptr(new Dataset());
+ if (store_raw_) {
+ dataset->SetHasRaw(true);
+ }
data_size_t num_global_data = 0;
std::vector used_data_indices;
auto bin_filename = CheckCanLoadFromBin(filename);
@@ -205,6 +212,9 @@ Dataset* DatasetLoader::LoadFromFile(const char* filename, int rank, int num_mac
static_cast(dataset->num_data_));
// construct feature bin mappers
ConstructBinMappersFromTextData(rank, num_machines, sample_data, parser.get(), dataset.get());
+ if (dataset->has_raw()) {
+ dataset->ResizeRaw(dataset->num_data_);
+ }
// initialize label
dataset->metadata_.Init(dataset->num_data_, weight_idx_, group_idx_);
// extract features
@@ -222,6 +232,9 @@ Dataset* DatasetLoader::LoadFromFile(const char* filename, int rank, int num_mac
static_cast(dataset->num_data_));
// construct feature bin mappers
ConstructBinMappersFromTextData(rank, num_machines, sample_data, parser.get(), dataset.get());
+ if (dataset->has_raw()) {
+ dataset->ResizeRaw(dataset->num_data_);
+ }
// initialize label
dataset->metadata_.Init(dataset->num_data_, weight_idx_, group_idx_);
Log::Debug("Making second pass...");
@@ -248,6 +261,9 @@ Dataset* DatasetLoader::LoadFromFileAlignWithOtherDataset(const char* filename,
data_size_t num_global_data = 0;
std::vector used_data_indices;
auto dataset = std::unique_ptr(new Dataset());
+ if (store_raw_) {
+ dataset->SetHasRaw(true);
+ }
auto bin_filename = CheckCanLoadFromBin(filename);
if (bin_filename.size() == 0) {
auto parser = std::unique_ptr(Parser::CreateParser(filename, config_.header, 0, label_idx_));
@@ -264,6 +280,9 @@ Dataset* DatasetLoader::LoadFromFileAlignWithOtherDataset(const char* filename,
// initialize label
dataset->metadata_.Init(dataset->num_data_, weight_idx_, group_idx_);
dataset->CreateValid(train_data);
+ if (dataset->has_raw()) {
+ dataset->ResizeRaw(dataset->num_data_);
+ }
// extract features
ExtractFeaturesFromMemory(&text_data, parser.get(), dataset.get());
text_data.clear();
@@ -275,6 +294,9 @@ Dataset* DatasetLoader::LoadFromFileAlignWithOtherDataset(const char* filename,
// initialize label
dataset->metadata_.Init(dataset->num_data_, weight_idx_, group_idx_);
dataset->CreateValid(train_data);
+ if (dataset->has_raw()) {
+ dataset->ResizeRaw(dataset->num_data_);
+ }
// extract features
ExtractFeaturesFromFile(filename, parser.get(), used_data_indices, dataset.get());
}
@@ -356,6 +378,8 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b
mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->use_missing_));
dataset->zero_as_missing_ = *(reinterpret_cast(mem_ptr));
mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->zero_as_missing_));
+ dataset->has_raw_ = *(reinterpret_cast<const bool*>(mem_ptr));
+ mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->has_raw_));
const int* tmp_feature_map = reinterpret_cast(mem_ptr);
dataset->used_feature_map_.clear();
for (int i = 0; i < dataset->num_total_features_; ++i) {
@@ -550,6 +574,43 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b
*used_data_indices, i)));
}
dataset->feature_groups_.shrink_to_fit();
+
+ // raw data
+ dataset->numeric_feature_map_ = std::vector<int>(dataset->num_features_, false);
+ dataset->num_numeric_features_ = 0;
+ for (int i = 0; i < dataset->num_features_; ++i) {
+ if (dataset->FeatureBinMapper(i)->bin_type() == BinType::CategoricalBin) {
+ dataset->numeric_feature_map_[i] = -1;
+ } else {
+ dataset->numeric_feature_map_[i] = dataset->num_numeric_features_;
+ ++dataset->num_numeric_features_;
+ }
+ }
+ if (dataset->has_raw()) {
+ dataset->ResizeRaw(dataset->num_data());
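+ // raw data was written row by row in SaveBinaryFile, so read it back one row at a time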
+ size_t row_size = dataset->num_numeric_features_ * sizeof(float);
+ if (row_size > buffer_size) {
+ buffer_size = row_size;
+ buffer.resize(buffer_size);
+ }
+ for (int i = 0; i < dataset->num_data(); ++i) {
+ read_cnt = reader->Read(buffer.data(), row_size);
+ if (read_cnt != row_size) {
+ Log::Fatal("Binary file error: row %d of raw data is incorrect, read count: %d", i, read_cnt);
+ }
+ mem_ptr = buffer.data();
+ const float* tmp_ptr_raw_row = reinterpret_cast<const float*>(mem_ptr);
+ std::vector<float> curr_row(dataset->num_numeric_features_, 0);
+ for (int j = 0; j < dataset->num_features(); ++j) {
+ int feat_ind = dataset->numeric_feature_map_[j];
+ if (feat_ind >= 0) {
+ dataset->raw_data_[feat_ind][i] = tmp_ptr_raw_row[feat_ind];
+ }
+ }
+ mem_ptr += row_size;
+ }
+ }
+
dataset->is_finish_load_ = true;
return dataset.release();
}
@@ -704,6 +765,9 @@ Dataset* DatasetLoader::ConstructFromSampleData(double** sample_values,
}
auto dataset = std::unique_ptr(new Dataset(num_data));
dataset->Construct(&bin_mappers, num_total_features, forced_bin_bounds, sample_indices, sample_values, num_per_col, num_col, total_sample_size, config_);
+ if (dataset->has_raw()) {
+ dataset->ResizeRaw(num_data);
+ }
dataset->set_feature_names(feature_names_);
return dataset.release();
}
@@ -1061,6 +1125,9 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
dataset->Construct(&bin_mappers, dataset->num_total_features_, forced_bin_bounds, Common::Vector2Ptr(&sample_indices).data(),
Common::Vector2Ptr(&sample_values).data(),
Common::VectorSize(sample_indices).data(), static_cast(sample_indices.size()), sample_data.size(), config_);
+ if (dataset->has_raw()) {
+ dataset->ResizeRaw(sample_data.size());
+ }
}
/*! \brief Extract local features from memory */
@@ -1068,10 +1135,11 @@ void DatasetLoader::ExtractFeaturesFromMemory(std::vector* text_dat
std::vector> oneline_features;
double tmp_label = 0.0f;
auto& ref_text_data = *text_data;
+ std::vector<double> feature_row(dataset->num_features_);
if (predict_fun_ == nullptr) {
OMP_INIT_EX();
// if doesn't need to prediction with initial model
- #pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label)
+ #pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label, feature_row)
for (data_size_t i = 0; i < dataset->num_data_; ++i) {
OMP_LOOP_EX_BEGIN();
const int tid = omp_get_thread_num();
@@ -1095,6 +1163,9 @@ void DatasetLoader::ExtractFeaturesFromMemory(std::vector* text_dat
int group = dataset->feature2group_[feature_idx];
int sub_feature = dataset->feature2subfeature_[feature_idx];
dataset->feature_groups_[group]->PushData(tid, sub_feature, i, inner_data.second);
+ if (dataset->has_raw()) {
+ feature_row[feature_idx] = inner_data.second;
+ }
} else {
if (inner_data.first == weight_idx_) {
dataset->metadata_.SetWeightAt(i, static_cast(inner_data.second));
@@ -1103,6 +1174,14 @@ void DatasetLoader::ExtractFeaturesFromMemory(std::vector* text_dat
}
}
}
+ if (dataset->has_raw()) {
+ for (size_t j = 0; j < feature_row.size(); ++j) {
+ int feat_ind = dataset->numeric_feature_map_[j];
+ if (feat_ind >= 0) {
+ dataset->raw_data_[feat_ind][i] = feature_row[j];
+ }
+ }
+ }
dataset->FinishOneRow(tid, i, is_feature_added);
OMP_LOOP_EX_END();
}
@@ -1111,7 +1190,7 @@ void DatasetLoader::ExtractFeaturesFromMemory(std::vector* text_dat
OMP_INIT_EX();
// if need to prediction with initial model
std::vector init_score(dataset->num_data_ * num_class_);
- #pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label)
+ #pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label, feature_row)
for (data_size_t i = 0; i < dataset->num_data_; ++i) {
OMP_LOOP_EX_BEGIN();
const int tid = omp_get_thread_num();
@@ -1141,6 +1220,9 @@ void DatasetLoader::ExtractFeaturesFromMemory(std::vector* text_dat
int group = dataset->feature2group_[feature_idx];
int sub_feature = dataset->feature2subfeature_[feature_idx];
dataset->feature_groups_[group]->PushData(tid, sub_feature, i, inner_data.second);
+ if (dataset->has_raw()) {
+ feature_row[feature_idx] = inner_data.second;
+ }
} else {
if (inner_data.first == weight_idx_) {
dataset->metadata_.SetWeightAt(i, static_cast(inner_data.second));
@@ -1150,6 +1232,14 @@ void DatasetLoader::ExtractFeaturesFromMemory(std::vector* text_dat
}
}
dataset->FinishOneRow(tid, i, is_feature_added);
+ if (dataset->has_raw()) {
+ for (size_t j = 0; j < feature_row.size(); ++j) {
+ int feat_ind = dataset->numeric_feature_map_[j];
+ if (feat_ind >= 0) {
+ dataset->raw_data_[feat_ind][i] = feature_row[j];
+ }
+ }
+ }
OMP_LOOP_EX_END();
}
OMP_THROW_EX();
@@ -1173,8 +1263,9 @@ void DatasetLoader::ExtractFeaturesFromFile(const char* filename, const Parser*
(data_size_t start_idx, const std::vector& lines) {
std::vector> oneline_features;
double tmp_label = 0.0f;
+ std::vector<double> feature_row(dataset->num_features_);
OMP_INIT_EX();
- #pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label)
+ #pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label, feature_row)
for (data_size_t i = 0; i < static_cast(lines.size()); ++i) {
OMP_LOOP_EX_BEGIN();
const int tid = omp_get_thread_num();
@@ -1202,6 +1293,9 @@ void DatasetLoader::ExtractFeaturesFromFile(const char* filename, const Parser*
int group = dataset->feature2group_[feature_idx];
int sub_feature = dataset->feature2subfeature_[feature_idx];
dataset->feature_groups_[group]->PushData(tid, sub_feature, start_idx + i, inner_data.second);
+ if (dataset->has_raw()) {
+ feature_row[feature_idx] = inner_data.second;
+ }
} else {
if (inner_data.first == weight_idx_) {
dataset->metadata_.SetWeightAt(start_idx + i, static_cast(inner_data.second));
@@ -1210,6 +1304,14 @@ void DatasetLoader::ExtractFeaturesFromFile(const char* filename, const Parser*
}
}
}
+ if (dataset->has_raw()) {
+ for (size_t j = 0; j < feature_row.size(); ++j) {
+ int feat_ind = dataset->numeric_feature_map_[j];
+ if (feat_ind >= 0) {
+ dataset->raw_data_[feat_ind][i] = feature_row[j];
+ }
+ }
+ }
dataset->FinishOneRow(tid, i, is_feature_added);
OMP_LOOP_EX_END();
}
diff --git a/src/io/tree.cpp b/src/io/tree.cpp
index a4e42831b0e7..91eb0a0ab116 100644
--- a/src/io/tree.cpp
+++ b/src/io/tree.cpp
@@ -14,7 +14,7 @@
namespace LightGBM {
-Tree::Tree(int max_leaves, bool track_branch_features)
+Tree::Tree(int max_leaves, bool track_branch_features, bool is_linear)
:max_leaves_(max_leaves), track_branch_features_(track_branch_features) {
left_child_.resize(max_leaves_ - 1);
right_child_.resize(max_leaves_ - 1);
@@ -46,6 +46,13 @@ Tree::Tree(int max_leaves, bool track_branch_features)
cat_boundaries_.push_back(0);
cat_boundaries_inner_.push_back(0);
max_depth_ = -1;
+ is_linear_ = is_linear;
+ if (is_linear_) {
+ leaf_coeff_.resize(max_leaves_);
+ leaf_const_ = std::vector<double>(max_leaves_, 0);
+ leaf_features_.resize(max_leaves_);
+ leaf_features_inner_.resize(max_leaves_);
+ }
}
int Tree::Split(int leaf, int feature, int real_feature, uint32_t threshold_bin,
@@ -103,8 +110,42 @@ int Tree::SplitCategorical(int leaf, int feature, int real_feature, const uint32
score[(data_idx)] += static_cast(leaf_value_[~node]); \
}\
+
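+// Same traversal as PredictionFun above, but adds each leaf's linear model output; if any model feature is NaN for a row, the constant leaf value is used instead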
+#define PredictionFunLinear(niter, fidx_in_iter, start_pos, decision_fun, \
+ iter_idx, data_idx) \
+ std::vector<std::unique_ptr<BinIterator>> iter((niter)); \
+ for (int i = 0; i < (niter); ++i) { \
+ iter[i].reset(data->FeatureIterator((fidx_in_iter))); \
+ iter[i]->Reset((start_pos)); \
+ } \
+ for (data_size_t i = start; i < end; ++i) { \
+ int node = 0; \
+ while (node >= 0) { \
+ node = decision_fun(iter[(iter_idx)]->Get((data_idx)), node, \
+ default_bins[node], max_bins[node]); \
+ } \
+ double add_score = leaf_const_[~node]; \
+ bool nan_found = false; \
+ const double* coeff_ptr = leaf_coeff_[~node].data(); \
+ const float** data_ptr = feat_ptr[~node].data(); \
+ for (size_t j = 0; j < leaf_features_inner_[~node].size(); ++j) { \
+ float feat_val = data_ptr[j][(data_idx)]; \
+ if (std::isnan(feat_val)) { \
+ nan_found = true; \
+ break; \
+ } \
+ add_score += coeff_ptr[j] * feat_val; \
+ } \
+ if (nan_found) { \
+ score[(data_idx)] += leaf_value_[~node]; \
+ } else { \
+ score[(data_idx)] += add_score; \
+ } \
+}\
+
+
void Tree::AddPredictionToScore(const Dataset* data, data_size_t num_data, double* score) const {
- if (num_leaves_ <= 1) {
+ if (!is_linear_ && num_leaves_ <= 1) {
if (leaf_value_[0] != 0.0f) {
#pragma omp parallel for schedule(static, 512) if (num_data >= 1024)
for (data_size_t i = 0; i < num_data; ++i) {
@@ -121,37 +162,71 @@ void Tree::AddPredictionToScore(const Dataset* data, data_size_t num_data, doubl
default_bins[i] = bin_mapper->GetDefaultBin();
max_bins[i] = bin_mapper->num_bin() - 1;
}
- if (num_cat_ > 0) {
- if (data->num_features() > num_leaves_ - 1) {
- Threading::For(0, num_data, 512, [this, &data, score, &default_bins, &max_bins]
- (int, data_size_t start, data_size_t end) {
- PredictionFun(num_leaves_ - 1, split_feature_inner_[i], start, DecisionInner, node, i);
- });
+ if (is_linear_) {
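+ // collect raw-data pointers for each leaf's linear model features once, before the threaded prediction loop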
+ std::vector<std::vector<const float*>> feat_ptr(num_leaves_);
+ for (int leaf_num = 0; leaf_num < num_leaves_; ++leaf_num) {
+ for (int feat : leaf_features_inner_[leaf_num]) {
+ feat_ptr[leaf_num].push_back(data->raw_index(feat));
+ }
+ }
+ if (num_cat_ > 0) {
+ if (data->num_features() > num_leaves_ - 1) {
+ Threading::For(0, num_data, 512, [this, &data, score, &default_bins, &max_bins, &feat_ptr]
+ (int, data_size_t start, data_size_t end) {
+ PredictionFunLinear(num_leaves_ - 1, split_feature_inner_[i], start, DecisionInner, node, i);
+ });
+ } else {
+ Threading::For(0, num_data, 512, [this, &data, score, &default_bins, &max_bins, &feat_ptr]
+ (int, data_size_t start, data_size_t end) {
+ PredictionFunLinear(data->num_features(), i, start, DecisionInner, split_feature_inner_[node], i);
+ });
+ }
} else {
- Threading::For(0, num_data, 512, [this, &data, score, &default_bins, &max_bins]
- (int, data_size_t start, data_size_t end) {
- PredictionFun(data->num_features(), i, start, DecisionInner, split_feature_inner_[node], i);
- });
+ if (data->num_features() > num_leaves_ - 1) {
+ Threading::For(0, num_data, 512, [this, &data, score, &default_bins, &max_bins, &feat_ptr]
+ (int, data_size_t start, data_size_t end) {
+ PredictionFunLinear(num_leaves_ - 1, split_feature_inner_[i], start, NumericalDecisionInner, node, i);
+ });
+ } else {
+ Threading::For(0, num_data, 512, [this, &data, score, &default_bins, &max_bins, &feat_ptr]
+ (int, data_size_t start, data_size_t end) {
+ PredictionFunLinear(data->num_features(), i, start, NumericalDecisionInner, split_feature_inner_[node], i);
+ });
+ }
}
} else {
- if (data->num_features() > num_leaves_ - 1) {
- Threading::For(0, num_data, 512, [this, &data, score, &default_bins, &max_bins]
- (int, data_size_t start, data_size_t end) {
- PredictionFun(num_leaves_ - 1, split_feature_inner_[i], start, NumericalDecisionInner, node, i);
- });
+ if (num_cat_ > 0) {
+ if (data->num_features() > num_leaves_ - 1) {
+ Threading::For(0, num_data, 512, [this, &data, score, &default_bins, &max_bins]
+ (int, data_size_t start, data_size_t end) {
+ PredictionFun(num_leaves_ - 1, split_feature_inner_[i], start, DecisionInner, node, i);
+ });
+ } else {
+ Threading::For(0, num_data, 512, [this, &data, score, &default_bins, &max_bins]
+ (int, data_size_t start, data_size_t end) {
+ PredictionFun(data->num_features(), i, start, DecisionInner, split_feature_inner_[node], i);
+ });
+ }
} else {
- Threading::For(0, num_data, 512, [this, &data, score, &default_bins, &max_bins]
- (int, data_size_t start, data_size_t end) {
- PredictionFun(data->num_features(), i, start, NumericalDecisionInner, split_feature_inner_[node], i);
- });
+ if (data->num_features() > num_leaves_ - 1) {
+ Threading::For(0, num_data, 512, [this, &data, score, &default_bins, &max_bins]
+ (int, data_size_t start, data_size_t end) {
+ PredictionFun(num_leaves_ - 1, split_feature_inner_[i], start, NumericalDecisionInner, node, i);
+ });
+ } else {
+ Threading::For(0, num_data, 512, [this, &data, score, &default_bins, &max_bins]
+ (int, data_size_t start, data_size_t end) {
+ PredictionFun(data->num_features(), i, start, NumericalDecisionInner, split_feature_inner_[node], i);
+ });
+ }
}
}
}
void Tree::AddPredictionToScore(const Dataset* data,
- const data_size_t* used_data_indices,
- data_size_t num_data, double* score) const {
- if (num_leaves_ <= 1) {
+ const data_size_t* used_data_indices,
+ data_size_t num_data, double* score) const {
+ if (!is_linear_ && num_leaves_ <= 1) {
if (leaf_value_[0] != 0.0f) {
#pragma omp parallel for schedule(static, 512) if (num_data >= 1024)
for (data_size_t i = 0; i < num_data; ++i) {
@@ -168,34 +243,72 @@ void Tree::AddPredictionToScore(const Dataset* data,
default_bins[i] = bin_mapper->GetDefaultBin();
max_bins[i] = bin_mapper->num_bin() - 1;
}
- if (num_cat_ > 0) {
- if (data->num_features() > num_leaves_ - 1) {
- Threading::For(0, num_data, 512, [this, &data, score, used_data_indices, &default_bins, &max_bins]
- (int, data_size_t start, data_size_t end) {
- PredictionFun(num_leaves_ - 1, split_feature_inner_[i], used_data_indices[start], DecisionInner, node, used_data_indices[i]);
- });
+ if (is_linear_) {
+ std::vector<std::vector<const float*>> feat_ptr(num_leaves_);
+ for (int leaf_num = 0; leaf_num < num_leaves_; ++leaf_num) {
+ for (int feat : leaf_features_inner_[leaf_num]) {
+ feat_ptr[leaf_num].push_back(data->raw_index(feat));
+ }
+ }
+ if (num_cat_ > 0) {
+ if (data->num_features() > num_leaves_ - 1) {
+ Threading::For(0, num_data, 512, [this, &data, score, used_data_indices, &default_bins, &max_bins, &feat_ptr]
+ (int, data_size_t start, data_size_t end) {
+ PredictionFunLinear(num_leaves_ - 1, split_feature_inner_[i], used_data_indices[start], DecisionInner,
+ node, used_data_indices[i]);
+ });
+ } else {
+ Threading::For(0, num_data, 512, [this, &data, score, used_data_indices, &default_bins, &max_bins, &feat_ptr]
+ (int, data_size_t start, data_size_t end) {
+ PredictionFunLinear(data->num_features(), i, used_data_indices[start], DecisionInner, split_feature_inner_[node], used_data_indices[i]);
+ });
+ }
} else {
- Threading::For(0, num_data, 512, [this, &data, score, used_data_indices, &default_bins, &max_bins]
- (int, data_size_t start, data_size_t end) {
- PredictionFun(data->num_features(), i, used_data_indices[start], DecisionInner, split_feature_inner_[node], used_data_indices[i]);
- });
+ if (data->num_features() > num_leaves_ - 1) {
+ Threading::For(0, num_data, 512, [this, &data, score, used_data_indices, &default_bins, &max_bins, &feat_ptr]
+ (int, data_size_t start, data_size_t end) {
+ PredictionFunLinear(num_leaves_ - 1, split_feature_inner_[i], used_data_indices[start], NumericalDecisionInner,
+ node, used_data_indices[i]);
+ });
+ } else {
+ Threading::For(0, num_data, 512, [this, &data, score, used_data_indices, &default_bins, &max_bins, &feat_ptr]
+ (int, data_size_t start, data_size_t end) {
+ PredictionFunLinear(data->num_features(), i, used_data_indices[start], NumericalDecisionInner,
+ split_feature_inner_[node], used_data_indices[i]);
+ });
+ }
}
} else {
- if (data->num_features() > num_leaves_ - 1) {
- Threading::For(0, num_data, 512, [this, &data, score, used_data_indices, &default_bins, &max_bins]
- (int, data_size_t start, data_size_t end) {
- PredictionFun(num_leaves_ - 1, split_feature_inner_[i], used_data_indices[start], NumericalDecisionInner, node, used_data_indices[i]);
- });
+ if (num_cat_ > 0) {
+ if (data->num_features() > num_leaves_ - 1) {
+ Threading::For(0, num_data, 512, [this, &data, score, used_data_indices, &default_bins, &max_bins]
+ (int, data_size_t start, data_size_t end) {
+ PredictionFun(num_leaves_ - 1, split_feature_inner_[i], used_data_indices[start], DecisionInner, node, used_data_indices[i]);
+ });
+ } else {
+ Threading::For(0, num_data, 512, [this, &data, score, used_data_indices, &default_bins, &max_bins]
+ (int, data_size_t start, data_size_t end) {
+ PredictionFun(data->num_features(), i, used_data_indices[start], DecisionInner, split_feature_inner_[node], used_data_indices[i]);
+ });
+ }
} else {
- Threading::For(0, num_data, 512, [this, &data, score, used_data_indices, &default_bins, &max_bins]
- (int, data_size_t start, data_size_t end) {
- PredictionFun(data->num_features(), i, used_data_indices[start], NumericalDecisionInner, split_feature_inner_[node], used_data_indices[i]);
- });
+ if (data->num_features() > num_leaves_ - 1) {
+ Threading::For(0, num_data, 512, [this, &data, score, used_data_indices, &default_bins, &max_bins]
+ (int, data_size_t start, data_size_t end) {
+ PredictionFun(num_leaves_ - 1, split_feature_inner_[i], used_data_indices[start], NumericalDecisionInner, node, used_data_indices[i]);
+ });
+ } else {
+ Threading::For(0, num_data, 512, [this, &data, score, used_data_indices, &default_bins, &max_bins]
+ (int, data_size_t start, data_size_t end) {
+ PredictionFun(data->num_features(), i, used_data_indices[start], NumericalDecisionInner, split_feature_inner_[node], used_data_indices[i]);
+ });
+ }
}
}
}
#undef PredictionFun
+#undef PredictionFunLinear
double Tree::GetUpperBoundValue() const {
double upper_bound = leaf_value_[0];
@@ -259,8 +372,37 @@ std::string Tree::ToString() const {
str_buf << "cat_threshold="
<< ArrayToString(cat_threshold_, cat_threshold_.size()) << '\n';
}
+ str_buf << "is_linear=" << is_linear_ << '\n';
+
+ if (is_linear_) {
+ str_buf << "leaf_const="
+ << ArrayToString(leaf_const_, num_leaves_) << '\n';
+ std::vector<int> num_feat(num_leaves_);
+ for (int i = 0; i < num_leaves_; ++i) {
+ num_feat[i] = leaf_coeff_[i].size();
+ }
+ str_buf << "num_features="
+ << ArrayToString(num_feat, num_leaves_) << '\n';
+ str_buf << "leaf_features=";
+ for (int i = 0; i < num_leaves_; ++i) {
+ if (num_feat[i] > 0) {
+ str_buf << ArrayToString(leaf_features_[i], leaf_features_[i].size()) << ' ';
+ }
+ str_buf << ' ';
+ }
+ str_buf << '\n';
+ str_buf << "leaf_coeff=";
+ for (int i = 0; i < num_leaves_; ++i) {
+ if (num_feat[i] > 0) {
+ str_buf << ArrayToString(leaf_coeff_[i], leaf_coeff_[i].size()) << ' ';
+ }
+ str_buf << ' ';
+ }
+ str_buf << '\n';
+ }
str_buf << "shrinkage=" << shrinkage_ << '\n';
str_buf << '\n';
+
return str_buf.str();
}
@@ -508,7 +650,7 @@ std::string Tree::NodeToIfElseByMap(int index, bool predict_leaf_index) const {
Tree::Tree(const char* str, size_t* used_len) {
auto p = str;
 std::unordered_map<std::string, std::string> key_vals;
- const int max_num_line = 17;
+ const int max_num_line = 22;
int read_line = 0;
while (read_line < max_num_line) {
if (*p == '\r' || *p == '\n') break;
@@ -549,7 +691,13 @@ Tree::Tree(const char* str, size_t* used_len) {
shrinkage_ = 1.0f;
}
- if (num_leaves_ <= 1) { return; }
+ if (key_vals.count("is_linear")) {
+ int is_linear_int;
+ Common::Atoi(key_vals["is_linear"].c_str(), &is_linear_int);
+ is_linear_ = static_cast<bool>(is_linear_int);
+ }
+
+ if ((num_leaves_ <= 1) && !is_linear_) { return; }
if (key_vals.count("left_child")) {
 left_child_ = CommonC::StringToArrayFast<int>(key_vals["left_child"], num_leaves_ - 1);
@@ -617,6 +765,45 @@ Tree::Tree(const char* str, size_t* used_len) {
 decision_type_ = std::vector<int8_t>(num_leaves_ - 1, 0);
}
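+ // linear-model fields written by Tree::ToString(): leaf_const, num_features (number of
+ // coefficients per leaf), then the flattened leaf_features and leaf_coeff arrays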
+ if (is_linear_) {
+ if (key_vals.count("leaf_const")) {
+ leaf_const_ = Common::StringToArrayFast<double>(key_vals["leaf_const"], num_leaves_);
+ } else {
+ leaf_const_.resize(num_leaves_);
+ }
+ std::vector<int> num_feat;
+ if (key_vals.count("num_features")) {
+ num_feat = Common::StringToArrayFast<int>(key_vals["num_features"], num_leaves_);
+ }
+ leaf_coeff_.resize(num_leaves_);
+ leaf_features_.resize(num_leaves_);
+ leaf_features_inner_.resize(num_leaves_);
+ if (num_feat.size() > 0) {
+ int total_num_feat = 0;
+ for (size_t i = 0; i < num_feat.size(); ++i) { total_num_feat += num_feat[i]; }
+ std::vector<int> all_leaf_features;
+ if (key_vals.count("leaf_features")) {
+ all_leaf_features = Common::StringToArrayFast<int>(key_vals["leaf_features"], total_num_feat);
+ }
+ std::vector<double> all_leaf_coeff;
+ if (key_vals.count("leaf_coeff")) {
+ all_leaf_coeff = Common::StringToArrayFast<double>(key_vals["leaf_coeff"], total_num_feat);
+ }
+ int sum_num_feat = 0;
+ for (int i = 0; i < num_leaves_; ++i) {
+ if (num_feat[i] > 0) {
+ if (key_vals.count("leaf_features")) {
+ leaf_features_[i].assign(all_leaf_features.begin() + sum_num_feat, all_leaf_features.begin() + sum_num_feat + num_feat[i]);
+ }
+ if (key_vals.count("leaf_coeff")) {
+ leaf_coeff_[i].assign(all_leaf_coeff.begin() + sum_num_feat, all_leaf_coeff.begin() + sum_num_feat + num_feat[i]);
+ }
+ }
+ sum_num_feat += num_feat[i];
+ }
+ }
+ }
+
if (num_cat_ > 0) {
if (key_vals.count("cat_boundaries")) {
 cat_boundaries_ = CommonC::StringToArrayFast<int>(key_vals["cat_boundaries"], num_cat_ + 1);
diff --git a/src/treelearner/gpu_tree_learner.cpp b/src/treelearner/gpu_tree_learner.cpp
index df90aafb945c..efe36072544b 100644
--- a/src/treelearner/gpu_tree_learner.cpp
+++ b/src/treelearner/gpu_tree_learner.cpp
@@ -734,8 +734,8 @@ void GPUTreeLearner::InitGPU(int platform_id, int device_id) {
SetupKernelArguments();
}
-Tree* GPUTreeLearner::Train(const score_t* gradients, const score_t *hessians) {
- return SerialTreeLearner::Train(gradients, hessians);
+Tree* GPUTreeLearner::Train(const score_t* gradients, const score_t *hessians, bool is_first_tree) {
+ return SerialTreeLearner::Train(gradients, hessians, is_first_tree);
}
void GPUTreeLearner::ResetTrainingDataInner(const Dataset* train_data, bool is_constant_hessian, bool reset_multi_val_bin) {
diff --git a/src/treelearner/gpu_tree_learner.h b/src/treelearner/gpu_tree_learner.h
index a909c57cbadc..91980ce934ea 100644
--- a/src/treelearner/gpu_tree_learner.h
+++ b/src/treelearner/gpu_tree_learner.h
@@ -48,7 +48,7 @@ class GPUTreeLearner: public SerialTreeLearner {
void Init(const Dataset* train_data, bool is_constant_hessian) override;
void ResetTrainingDataInner(const Dataset* train_data, bool is_constant_hessian, bool reset_multi_val_bin) override;
void ResetIsConstantHessian(bool is_constant_hessian) override;
- Tree* Train(const score_t* gradients, const score_t *hessians) override;
+ Tree* Train(const score_t* gradients, const score_t *hessians, bool is_first_tree) override;
void SetBaggingData(const Dataset* subset, const data_size_t* used_indices, data_size_t num_data) override {
SerialTreeLearner::SetBaggingData(subset, used_indices, num_data);
diff --git a/src/treelearner/linear_tree_learner.cpp b/src/treelearner/linear_tree_learner.cpp
new file mode 100644
index 000000000000..58482b8eef9b
--- /dev/null
+++ b/src/treelearner/linear_tree_learner.cpp
@@ -0,0 +1,390 @@
+/*!
+ * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT License. See LICENSE file in the project root for license information.
+ */
+#include "linear_tree_learner.h"
+
+#include <algorithm>
+
+#ifndef LGB_R_BUILD
+// preprocessor definition ensures we use only MPL2-licensed code
+#define EIGEN_MPL2_ONLY
+#include <Eigen/Dense>
+#endif // !LGB_R_BUILD
+
+namespace LightGBM {
+
+void LinearTreeLearner::Init(const Dataset* train_data, bool is_constant_hessian) {
+ SerialTreeLearner::Init(train_data, is_constant_hessian);
+ LinearTreeLearner::InitLinear(train_data, config_->num_leaves);
+}
+
+void LinearTreeLearner::InitLinear(const Dataset* train_data, const int max_leaves) {
+ leaf_map_ = std::vector<int>(train_data->num_data(), -1);
+ contains_nan_ = std::vector<int8_t>(train_data->num_features(), 0);
+ // identify features containing nans
+#pragma omp parallel for schedule(static)
+ for (int feat = 0; feat < train_data->num_features(); ++feat) {
+ auto bin_mapper = train_data_->FeatureBinMapper(feat);
+ if (bin_mapper->bin_type() == BinType::NumericalBin) {
+ const float* feat_ptr = train_data_->raw_index(feat);
+ for (int i = 0; i < train_data->num_data(); ++i) {
+ if (std::isnan(feat_ptr[i])) {
+ contains_nan_[feat] = 1;
+ break;
+ }
+ }
+ }
+ }
+ for (int feat = 0; feat < train_data->num_features(); ++feat) {
+ if (contains_nan_[feat]) {
+ any_nan_ = true;
+ break;
+ }
+ }
+ // preallocate the matrix used to calculate linear model coefficients
+ int max_num_feat = std::min(max_leaves, train_data_->num_numeric_features());
+ XTHX_.clear();
+ XTg_.clear();
+ for (int i = 0; i < max_leaves; ++i) {
+ // store only upper triangular half of matrix as an array, in row-major order
+ // this requires (max_num_feat + 1) * (max_num_feat + 2) / 2 entries (including the constant terms of the regression)
+ // we add another 8 to ensure cache lines are not shared among processors
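+ // e.g. with 2 features the 3x3 symmetric matrix is packed as
+ // [(0,0) (0,1) (0,2) (1,1) (1,2) (2,2)]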
+ XTHX_.push_back(std::vector<double>((max_num_feat + 1) * (max_num_feat + 2) / 2 + 8, 0));
+ XTg_.push_back(std::vector<double>(max_num_feat + 9, 0.0));
+ }
+ XTHX_by_thread_.clear();
+ XTg_by_thread_.clear();
+ int max_threads = omp_get_max_threads();
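+ // keep a private copy of the accumulation buffers for each OpenMP thread; they are
+ // merged after the parallel pass in CalculateLinear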
+ for (int i = 0; i < max_threads; ++i) {
+ XTHX_by_thread_.push_back(XTHX_);
+ XTg_by_thread_.push_back(XTg_);
+ }
+}
+
+Tree* LinearTreeLearner::Train(const score_t* gradients, const score_t *hessians, bool is_first_tree) {
+ Common::FunctionTimer fun_timer("LinearTreeLearner::Train", global_timer);
+ gradients_ = gradients;
+ hessians_ = hessians;
+ int num_threads = OMP_NUM_THREADS();
+ if (share_state_->num_threads != num_threads && share_state_->num_threads > 0) {
+ Log::Warning(
+ "Detected that num_threads changed during training (from %d to %d), "
+ "it may cause unexpected errors.",
+ share_state_->num_threads, num_threads);
+ }
+ share_state_->num_threads = num_threads;
+
+ // some initial work before training
+ BeforeTrain();
+
+ auto tree = std::unique_ptr<Tree>(new Tree(config_->num_leaves, true, true));
+ auto tree_ptr = tree.get();
+ constraints_->ShareTreePointer(tree_ptr);
+
+ // root leaf
+ int left_leaf = 0;
+ int cur_depth = 1;
+ // only the root leaf can be split the first time
+ int right_leaf = -1;
+
+ int init_splits = ForceSplits(tree_ptr, &left_leaf, &right_leaf, &cur_depth);
+
+ for (int split = init_splits; split < config_->num_leaves - 1; ++split) {
+ // some initial work before finding the best split
+ if (BeforeFindBestSplit(tree_ptr, left_leaf, right_leaf)) {
+ // find best threshold for every feature
+ FindBestSplits(tree_ptr);
+ }
+ // Get a leaf with max split gain
+ int best_leaf = static_cast<int>(ArrayArgs<SplitInfo>::ArgMax(best_split_per_leaf_));
+ // Get split information for best leaf
+ const SplitInfo& best_leaf_SplitInfo = best_split_per_leaf_[best_leaf];
+ // cannot split, quit
+ if (best_leaf_SplitInfo.gain <= 0.0) {
+ Log::Warning("No further splits with positive gain, best gain: %f", best_leaf_SplitInfo.gain);
+ break;
+ }
+ // split tree with best leaf
+ Split(tree_ptr, best_leaf, &left_leaf, &right_leaf);
+ cur_depth = std::max(cur_depth, tree->leaf_depth(left_leaf));
+ }
+
+ bool has_nan = false;
+ if (any_nan_) {
+ for (int i = 0; i < tree->num_leaves() - 1 ; ++i) {
+ if (contains_nan_[tree_ptr->split_feature_inner(i)]) {
+ has_nan = true;
+ break;
+ }
+ }
+ }
+
+ GetLeafMap(tree_ptr);
+
+ if (has_nan) {
+ CalculateLinear<true>(tree_ptr, false, gradients_, hessians_, is_first_tree);
+ } else {
+ CalculateLinear<false>(tree_ptr, false, gradients_, hessians_, is_first_tree);
+ }
+
+ Log::Debug("Trained a tree with leaves = %d and max_depth = %d", tree->num_leaves(), cur_depth);
+ return tree.release();
+}
+
+Tree* LinearTreeLearner::FitByExistingTree(const Tree* old_tree, const score_t* gradients, const score_t *hessians) const {
+ auto tree = SerialTreeLearner::FitByExistingTree(old_tree, gradients, hessians);
+ bool has_nan = false;
+ if (any_nan_) {
+ for (int i = 0; i < tree->num_leaves() - 1 ; ++i) {
+ if (contains_nan_[train_data_->InnerFeatureIndex(tree->split_feature(i))]) {
+ has_nan = true;
+ break;
+ }
+ }
+ }
+ GetLeafMap(tree);
+ if (has_nan) {
+ CalculateLinear<true>(tree, true, gradients, hessians, false);
+ } else {
+ CalculateLinear<false>(tree, true, gradients, hessians, false);
+ }
+ return tree;
+}
+
+Tree* LinearTreeLearner::FitByExistingTree(const Tree* old_tree, const std::vector<int>& leaf_pred,
+ const score_t* gradients, const score_t *hessians) const {
+ data_partition_->ResetByLeafPred(leaf_pred, old_tree->num_leaves());
+ return LinearTreeLearner::FitByExistingTree(old_tree, gradients, hessians);
+}
+
+void LinearTreeLearner::GetLeafMap(Tree* tree) const {
+ std::fill(leaf_map_.begin(), leaf_map_.end(), -1);
+ // map data to leaf number
+ const data_size_t* ind = data_partition_->indices();
+#pragma omp parallel for schedule(dynamic)
+ for (int i = 0; i < tree->num_leaves(); ++i) {
+ data_size_t idx = data_partition_->leaf_begin(i);
+ for (int j = 0; j < data_partition_->leaf_count(i); ++j) {
+ leaf_map_[ind[idx + j]] = i;
+ }
+ }
+}
+
+#ifdef LGB_R_BUILD
+template<bool HAS_NAN>
+void LinearTreeLearner::CalculateLinear(Tree* tree, bool is_refit, const score_t* gradients, const score_t* hessians, bool is_first_tree) const {
+ Log::Fatal("Linear tree learner does not work with R package.");
+}
+#else
+template<bool HAS_NAN>
+void LinearTreeLearner::CalculateLinear(Tree* tree, bool is_refit, const score_t* gradients, const score_t* hessians, bool is_first_tree) const {
+ tree->SetIsLinear(true);
+ int num_leaves = tree->num_leaves();
+ int num_threads = OMP_NUM_THREADS();
+ if (is_first_tree) {
+ for (int leaf_num = 0; leaf_num < num_leaves; ++leaf_num) {
+ tree->SetLeafConst(leaf_num, tree->LeafOutput(leaf_num));
+ }
+ return;
+ }
+
+ // calculate coefficients using the method described in Eq 3 of https://arxiv.org/pdf/1802.05640.pdf
+ // the coefficients vector is given by
+ // - (X_T * H * X + lambda) ^ (-1) * (X_T * g)
+ // where:
+ // X is the matrix where the first column is the feature values and the second is all ones,
+ // H is the diagonal matrix of the hessian,
+ // lambda is the diagonal matrix with diagonal entries equal to the regularisation term linear_lambda
+ // g is the vector of gradients
+ // the subscript _T denotes the transpose
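+ // in this code X has one column per numerical feature used in the leaf plus a final
+ // column of ones for the constant term, and only the upper triangle of X_T * H * X is
+ // accumulated (in row-major order) since the matrix is symmetric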
+
+ // create array of pointers to raw data, and coefficient matrices, for each leaf
+ std::vector<std::vector<int>> leaf_features;
+ std::vector<int> leaf_num_features;
+ std::vector<std::vector<const float*>> raw_data_ptr;
+ int max_num_features = 0;
+ for (int i = 0; i < num_leaves; ++i) {
+ std::vector<int> raw_features;
+ if (is_refit) {
+ raw_features = tree->LeafFeatures(i);
+ } else {
+ raw_features = tree->branch_features(i);
+ }
+ std::sort(raw_features.begin(), raw_features.end());
+ auto new_end = std::unique(raw_features.begin(), raw_features.end());
+ raw_features.erase(new_end, raw_features.end());
+ std::vector<int> numerical_features;
+ std::vector<const float*> data_ptr;
+ for (size_t j = 0; j < raw_features.size(); ++j) {
+ int feat = train_data_->InnerFeatureIndex(raw_features[j]);
+ auto bin_mapper = train_data_->FeatureBinMapper(feat);
+ if (bin_mapper->bin_type() == BinType::NumericalBin) {
+ numerical_features.push_back(feat);
+ data_ptr.push_back(train_data_->raw_index(feat));
+ }
+ }
+ leaf_features.push_back(numerical_features);
+ raw_data_ptr.push_back(data_ptr);
+ leaf_num_features.push_back(numerical_features.size());
+ if (static_cast<int>(numerical_features.size()) > max_num_features) {
+ max_num_features = numerical_features.size();
+ }
+ }
+ // clear the coefficient matrices
+#pragma omp parallel for schedule(static)
+ for (int i = 0; i < num_threads; ++i) {
+ for (int leaf_num = 0; leaf_num < num_leaves; ++leaf_num) {
+ int num_feat = leaf_features[leaf_num].size();
+ std::fill(XTHX_by_thread_[i][leaf_num].begin(), XTHX_by_thread_[i][leaf_num].begin() + (num_feat + 1) * (num_feat + 2) / 2, 0);
+ std::fill(XTg_by_thread_[i][leaf_num].begin(), XTg_by_thread_[i][leaf_num].begin() + num_feat + 1, 0);
+ }
+ }
+#pragma omp parallel for schedule(static)
+ for (int leaf_num = 0; leaf_num < num_leaves; ++leaf_num) {
+ int num_feat = leaf_features[leaf_num].size();
+ std::fill(XTHX_[leaf_num].begin(), XTHX_[leaf_num].begin() + (num_feat + 1) * (num_feat + 2) / 2, 0);
+ std::fill(XTg_[leaf_num].begin(), XTg_[leaf_num].begin() + num_feat + 1, 0);
+ }
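+ // when nans are present, count the rows in each leaf with no nan in the leaf's features;
+ // a leaf needs more such rows than coefficients before a linear model is fitted (checked below)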
+ std::vector<std::vector<int>> num_nonzero;
+ for (int i = 0; i < num_threads; ++i) {
+ if (HAS_NAN) {
+ num_nonzero.push_back(std::vector<int>(num_leaves, 0));
+ }
+ }
+ OMP_INIT_EX();
+#pragma omp parallel if (num_data_ > 1024)
+ {
+ std::vector<float> curr_row(max_num_features + 1);
+ int tid = omp_get_thread_num();
+#pragma omp for schedule(static)
+ for (int i = 0; i < num_data_; ++i) {
+ OMP_LOOP_EX_BEGIN();
+ int leaf_num = leaf_map_[i];
+ if (leaf_num < 0) {
+ continue;
+ }
+ bool nan_found = false;
+ int num_feat = leaf_num_features[leaf_num];
+ for (int feat = 0; feat < num_feat; ++feat) {
+ if (HAS_NAN) {
+ float val = raw_data_ptr[leaf_num][feat][i];
+ if (std::isnan(val)) {
+ nan_found = true;
+ break;
+ }
+ num_nonzero[tid][leaf_num] += 1;
+ curr_row[feat] = val;
+ } else {
+ curr_row[feat] = raw_data_ptr[leaf_num][feat][i];
+ }
+ }
+ if (HAS_NAN) {
+ if (nan_found) {
+ continue;
+ }
+ }
+ curr_row[num_feat] = 1.0;
+ double h = hessians[i];
+ double g = gradients[i];
+ int j = 0;
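+ // accumulate this row's contribution to X_T * H * X (packed upper triangle, indexed by j)
+ // and to X_T * g for the leaf containing this data point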
+ for (int feat1 = 0; feat1 < num_feat + 1; ++feat1) {
+ double f1_val = curr_row[feat1];
+ XTg_by_thread_[tid][leaf_num][feat1] += f1_val * g;
+ f1_val *= h;
+ for (int feat2 = feat1; feat2 < num_feat + 1; ++feat2) {
+ XTHX_by_thread_[tid][leaf_num][j] += f1_val * curr_row[feat2];
+ ++j;
+ }
+ }
+ OMP_LOOP_EX_END();
+ }
+ }
+ OMP_THROW_EX();
+ auto total_nonzero = std::vector<int>(tree->num_leaves());
+ // aggregate results from different threads
+ for (int tid = 0; tid < num_threads; ++tid) {
+#pragma omp parallel for schedule(static)
+ for (int leaf_num = 0; leaf_num < num_leaves; ++leaf_num) {
+ int num_feat = leaf_features[leaf_num].size();
+ for (int j = 0; j < (num_feat + 1) * (num_feat + 2) / 2; ++j) {
+ XTHX_[leaf_num][j] += XTHX_by_thread_[tid][leaf_num][j];
+ }
+ for (int feat1 = 0; feat1 < num_feat + 1; ++feat1) {
+ XTg_[leaf_num][feat1] += XTg_by_thread_[tid][leaf_num][feat1];
+ }
+ if (HAS_NAN) {
+ total_nonzero[leaf_num] += num_nonzero[tid][leaf_num];
+ }
+ }
+ }
+ if (!HAS_NAN) {
+ for (int leaf_num = 0; leaf_num < num_leaves; ++leaf_num) {
+ total_nonzero[leaf_num] = data_partition_->leaf_count(leaf_num);
+ }
+ }
+ double shrinkage = tree->shrinkage();
+ double decay_rate = config_->refit_decay_rate;
+ // copy into eigen matrices and solve
+#pragma omp parallel for schedule(static)
+ for (int leaf_num = 0; leaf_num < num_leaves; ++leaf_num) {
+ if (total_nonzero[leaf_num] < static_cast<int>(leaf_features[leaf_num].size()) + 1) {
+ if (is_refit) {
+ double old_const = tree->LeafConst(leaf_num);
+ tree->SetLeafConst(leaf_num, decay_rate * old_const + (1.0 - decay_rate) * tree->LeafOutput(leaf_num) * shrinkage);
+ tree->SetLeafCoeffs(leaf_num, std::vector<double>(leaf_features[leaf_num].size(), 0));
+ tree->SetLeafFeaturesInner(leaf_num, leaf_features[leaf_num]);
+ } else {
+ tree->SetLeafConst(leaf_num, tree->LeafOutput(leaf_num));
+ }
+ continue;
+ }
+ int num_feat = leaf_features[leaf_num].size();
+ Eigen::MatrixXd XTHX_mat(num_feat + 1, num_feat + 1);
+ Eigen::MatrixXd XTg_mat(num_feat + 1, 1);
+ int j = 0;
+ for (int feat1 = 0; feat1 < num_feat + 1; ++feat1) {
+ for (int feat2 = feat1; feat2 < num_feat + 1; ++feat2) {
+ XTHX_mat(feat1, feat2) = XTHX_[leaf_num][j];
+ XTHX_mat(feat2, feat1) = XTHX_mat(feat1, feat2);
+ if ((feat1 == feat2) && (feat1 < num_feat)) {
+ XTHX_mat(feat1, feat2) += config_->linear_lambda;
+ }
+ ++j;
+ }
+ XTg_mat(feat1) = XTg_[leaf_num][feat1];
+ }
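+ // solve the regularised normal equations; the explicit inverse is formed here via
+ // full-pivoting LU decomposition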
+ Eigen::MatrixXd coeffs = - XTHX_mat.fullPivLu().inverse() * XTg_mat;
+ std::vector<double> coeffs_vec;
+ std::vector<int> features_new;
+ std::vector<double> old_coeffs = tree->LeafCoeffs(leaf_num);
+ for (size_t i = 0; i < leaf_features[leaf_num].size(); ++i) {
+ if (is_refit) {
+ features_new.push_back(leaf_features[leaf_num][i]);
+ coeffs_vec.push_back(decay_rate * old_coeffs[i] + (1.0 - decay_rate) * coeffs(i) * shrinkage);
+ } else {
+ if (coeffs(i) < -kZeroThreshold || coeffs(i) > kZeroThreshold) {
+ coeffs_vec.push_back(coeffs(i));
+ int feat = leaf_features[leaf_num][i];
+ features_new.push_back(feat);
+ }
+ }
+ }
+ // update the tree properties
+ tree->SetLeafFeaturesInner(leaf_num, features_new);
+ std::vector<int> features_raw(features_new.size());
+ for (size_t i = 0; i < features_new.size(); ++i) {
+ features_raw[i] = train_data_->RealFeatureIndex(features_new[i]);
+ }
+ tree->SetLeafFeatures(leaf_num, features_raw);
+ tree->SetLeafCoeffs(leaf_num, coeffs_vec);
+ if (is_refit) {
+ double old_const = tree->LeafConst(leaf_num);
+ tree->SetLeafConst(leaf_num, decay_rate * old_const + (1.0 - decay_rate) * coeffs(num_feat) * shrinkage);
+ } else {
+ tree->SetLeafConst(leaf_num, coeffs(num_feat));
+ }
+ }
+}
+#endif // LGB_R_BUILD
+} // namespace LightGBM
diff --git a/src/treelearner/linear_tree_learner.h b/src/treelearner/linear_tree_learner.h
new file mode 100644
index 000000000000..88bc7f91ede1
--- /dev/null
+++ b/src/treelearner/linear_tree_learner.h
@@ -0,0 +1,128 @@
+/*!
+ * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT License. See LICENSE file in the project root for license information.
+ */
+#ifndef LIGHTGBM_TREELEARNER_LINEAR_TREE_LEARNER_H_
+#define LIGHTGBM_TREELEARNER_LINEAR_TREE_LEARNER_H_
+
+#include <LightGBM/dataset.h>
+#include <LightGBM/tree.h>
+#include <algorithm>
+#include <cmath>
+#include <cstdint>
+#include <vector>
+
+#include "serial_tree_learner.h"
+
+namespace LightGBM {
+
+class LinearTreeLearner: public SerialTreeLearner {
+ public:
+ explicit LinearTreeLearner(const Config* config) : SerialTreeLearner(config) {}
+
+ void Init(const Dataset* train_data, bool is_constant_hessian) override;
+
+ void InitLinear(const Dataset* train_data, const int max_leaves) override;
+
+ Tree* Train(const score_t* gradients, const score_t *hessians, bool is_first_tree) override;
+
+ /*! \brief Create array mapping dataset to leaf index, used for linear trees */
+ void GetLeafMap(Tree* tree) const;
+
+ template<bool HAS_NAN>
+ void CalculateLinear(Tree* tree, bool is_refit, const score_t* gradients, const score_t* hessians, bool is_first_tree) const;
+
+ Tree* FitByExistingTree(const Tree* old_tree, const score_t* gradients, const score_t* hessians) const override;
+
+ Tree* FitByExistingTree(const Tree* old_tree, const std::vector<int>& leaf_pred,
+ const score_t* gradients, const score_t* hessians) const override;
+
+ void AddPredictionToScore(const Tree* tree,
+ double* out_score) const override {
+ CHECK_LE(tree->num_leaves(), data_partition_->num_leaves());
+ bool has_nan = false;
+ if (any_nan_) {
+ for (int i = 0; i < tree->num_leaves() - 1 ; ++i) {
+ // use split_feature because split_feature_inner doesn't work when refitting existing tree
+ if (contains_nan_[train_data_->InnerFeatureIndex(tree->split_feature(i))]) {
+ has_nan = true;
+ break;
+ }
+ }
+ }
+ if (has_nan) {
+ AddPredictionToScoreInner<true>(tree, out_score);
+ } else {
+ AddPredictionToScoreInner<false>(tree, out_score);
+ }
+ }
+
+ template<bool HAS_NAN>
+ void AddPredictionToScoreInner(const Tree* tree, double* out_score) const {
+ int num_leaves = tree->num_leaves();
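+ // cache each leaf's constant, coefficients, fallback output, and raw feature pointers
+ // up front so the parallel scoring loop below only does array lookups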
+ std::vector<double> leaf_const(num_leaves);
+ std::vector<std::vector<double>> leaf_coeff(num_leaves);
+ std::vector<std::vector<const float*>> feat_ptr(num_leaves);
+ std::vector<double> leaf_output(num_leaves);
+ std::vector<int> leaf_num_features(num_leaves);
+ for (int leaf_num = 0; leaf_num < num_leaves; ++leaf_num) {
+ leaf_const[leaf_num] = tree->LeafConst(leaf_num);
+ leaf_coeff[leaf_num] = tree->LeafCoeffs(leaf_num);
+ leaf_output[leaf_num] = tree->LeafOutput(leaf_num);
+ for (int feat : tree->LeafFeaturesInner(leaf_num)) {
+ feat_ptr[leaf_num].push_back(train_data_->raw_index(feat));
+ }
+ leaf_num_features[leaf_num] = feat_ptr[leaf_num].size();
+ }
+ OMP_INIT_EX();
+#pragma omp parallel for schedule(static) if (num_data_ > 1024)
+ for (int i = 0; i < num_data_; ++i) {
+ OMP_LOOP_EX_BEGIN();
+ int leaf_num = leaf_map_[i];
+ if (leaf_num < 0) {
+ continue;
+ }
+ double output = leaf_const[leaf_num];
+ int num_feat = leaf_num_features[leaf_num];
+ if (HAS_NAN) {
+ bool nan_found = false;
+ for (int feat_ind = 0; feat_ind < num_feat; ++feat_ind) {
+ float val = feat_ptr[leaf_num][feat_ind][i];
+ if (std::isnan(val)) {
+ nan_found = true;
+ break;
+ }
+ output += val * leaf_coeff[leaf_num][feat_ind];
+ }
+ if (nan_found) {
+ out_score[i] += leaf_output[leaf_num];
+ } else {
+ out_score[i] += output;
+ }
+ } else {
+ for (int feat_ind = 0; feat_ind < num_feat; ++feat_ind) {
+ output += feat_ptr[leaf_num][feat_ind][i] * leaf_coeff[leaf_num][feat_ind];
+ }
+ out_score[i] += output;
+ }
+ OMP_LOOP_EX_END();
+ }
+ OMP_THROW_EX();
+ }
+
+ protected:
+ /*! \brief whether numerical features contain any nan values */
+ std::vector<int8_t> contains_nan_;
+ /*! \brief whether any numerical feature contains a nan value */
+ bool any_nan_ = false;
+ /*! \brief map dataset to leaves */
+ mutable std::vector<int> leaf_map_;
+ /*! \brief temporary storage for calculating linear model coefficients */
+ mutable std::vector<std::vector<double>> XTHX_;
+ mutable std::vector<std::vector<double>> XTg_;
+ mutable std::vector<std::vector<std::vector<double>>> XTHX_by_thread_;
+ mutable std::vector<std::vector<std::vector<double>>> XTg_by_thread_;
+};
+
+} // namespace LightGBM
+#endif  // LIGHTGBM_TREELEARNER_LINEAR_TREE_LEARNER_H_
diff --git a/src/treelearner/serial_tree_learner.cpp b/src/treelearner/serial_tree_learner.cpp
index cfa7d9b38b99..2fd7fcba5859 100644
--- a/src/treelearner/serial_tree_learner.cpp
+++ b/src/treelearner/serial_tree_learner.cpp
@@ -155,7 +155,7 @@ void SerialTreeLearner::ResetConfig(const Config* config) {
constraints_.reset(LeafConstraintsBase::Create(config_, config_->num_leaves, train_data_->num_features()));
}
-Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians) {
+Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians, bool /*is_first_tree*/) {
Common::FunctionTimer fun_timer("SerialTreeLearner::Train", global_timer);
gradients_ = gradients;
hessians_ = hessians;
@@ -172,7 +172,7 @@ Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians
BeforeTrain();
bool track_branch_features = !(config_->interaction_constraints_vector.empty());
- auto tree = std::unique_ptr<Tree>(new Tree(config_->num_leaves, track_branch_features));
+ auto tree = std::unique_ptr<Tree>(new Tree(config_->num_leaves, track_branch_features, false));
auto tree_ptr = tree.get();
constraints_->ShareTreePointer(tree_ptr);
@@ -203,6 +203,7 @@ Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians
Split(tree_ptr, best_leaf, &left_leaf, &right_leaf);
cur_depth = std::max(cur_depth, tree->leaf_depth(left_leaf));
}
+
Log::Debug("Trained a tree with leaves = %d and max_depth = %d", tree->num_leaves(), cur_depth);
return tree.release();
}
@@ -242,7 +243,8 @@ Tree* SerialTreeLearner::FitByExistingTree(const Tree* old_tree, const score_t*
return tree.release();
}
-Tree* SerialTreeLearner::FitByExistingTree(const Tree* old_tree, const std::vector<int>& leaf_pred, const score_t* gradients, const score_t *hessians) {
+Tree* SerialTreeLearner::FitByExistingTree(const Tree* old_tree, const std::vector<int>& leaf_pred,
+ const score_t* gradients, const score_t *hessians) const {
data_partition_->ResetByLeafPred(leaf_pred, old_tree->num_leaves());
return FitByExistingTree(old_tree, gradients, hessians);
}
diff --git a/src/treelearner/serial_tree_learner.h b/src/treelearner/serial_tree_learner.h
index fd06c235a8bf..1a0bda53b5db 100644
--- a/src/treelearner/serial_tree_learner.h
+++ b/src/treelearner/serial_tree_learner.h
@@ -39,6 +39,7 @@ using json11::Json;
/*! \brief forward declaration */
class CostEfficientGradientBoosting;
+
/*!
* \brief Used for learning a tree by single machine
*/
@@ -74,12 +75,12 @@ class SerialTreeLearner: public TreeLearner {
}
}
- Tree* Train(const score_t* gradients, const score_t *hessians) override;
+ Tree* Train(const score_t* gradients, const score_t *hessians, bool is_first_tree) override;
Tree* FitByExistingTree(const Tree* old_tree, const score_t* gradients, const score_t* hessians) const override;
 Tree* FitByExistingTree(const Tree* old_tree, const std::vector<int>& leaf_pred,
- const score_t* gradients, const score_t* hessians) override;
+ const score_t* gradients, const score_t* hessians) const override;
void SetBaggingData(const Dataset* subset, const data_size_t* used_indices, data_size_t num_data) override {
if (subset == nullptr) {
@@ -96,10 +97,10 @@ class SerialTreeLearner: public TreeLearner {
void AddPredictionToScore(const Tree* tree,
double* out_score) const override {
+ CHECK_LE(tree->num_leaves(), data_partition_->num_leaves());
if (tree->num_leaves() <= 1) {
return;
}
- CHECK_LE(tree->num_leaves(), data_partition_->num_leaves());
#pragma omp parallel for schedule(static, 1)
for (int i = 0; i < tree->num_leaves(); ++i) {
 double output = static_cast<double>(tree->LeafOutput(i));
@@ -189,7 +190,6 @@ class SerialTreeLearner: public TreeLearner {
FeatureHistogram* smaller_leaf_histogram_array_;
/*! \brief pointer to histograms array of larger leaf */
FeatureHistogram* larger_leaf_histogram_array_;
-
/*! \brief store best split points for all leaves */
std::vector best_split_per_leaf_;
/*! \brief store best split per feature for all leaves */
diff --git a/src/treelearner/tree_learner.cpp b/src/treelearner/tree_learner.cpp
index ab009a0b1008..04184932a335 100644
--- a/src/treelearner/tree_learner.cpp
+++ b/src/treelearner/tree_learner.cpp
@@ -8,13 +8,22 @@
#include "gpu_tree_learner.h"
#include "parallel_tree_learner.h"
#include "serial_tree_learner.h"
+#include "linear_tree_learner.h"
namespace LightGBM {
-TreeLearner* TreeLearner::CreateTreeLearner(const std::string& learner_type, const std::string& device_type, const Config* config) {
+TreeLearner* TreeLearner::CreateTreeLearner(const std::string& learner_type, const std::string& device_type,
+ const Config* config) {
if (device_type == std::string("cpu")) {
if (learner_type == std::string("serial")) {
- return new SerialTreeLearner(config);
+ if (config->linear_tree) {
+#ifdef LGB_R_BUILD
+ Log::Fatal("Linear tree learner does not work with R package.");
+#endif // LGB_R_BUILD
+ return new LinearTreeLearner(config);
+ } else {
+ return new SerialTreeLearner(config);
+ }
} else if (learner_type == std::string("feature")) {
return new FeatureParallelTreeLearner(config);
} else if (learner_type == std::string("data")) {
diff --git a/tests/python_package_test/test_basic.py b/tests/python_package_test/test_basic.py
index b499ca7b1ea2..ed410e90088b 100644
--- a/tests/python_package_test/test_basic.py
+++ b/tests/python_package_test/test_basic.py
@@ -99,6 +99,38 @@ def test_chunked_dataset(self):
train_data.construct()
valid_data.construct()
+ def test_chunked_dataset_linear(self):
+ X_train, X_test, y_train, y_test = train_test_split(*load_breast_cancer(return_X_y=True), test_size=0.1,
+ random_state=2)
+ chunk_size = X_train.shape[0] // 10 + 1
+ X_train = [X_train[i * chunk_size:(i + 1) * chunk_size, :] for i in range(X_train.shape[0] // chunk_size + 1)]
+ X_test = [X_test[i * chunk_size:(i + 1) * chunk_size, :] for i in range(X_test.shape[0] // chunk_size + 1)]
+ params = {"bin_construct_sample_cnt": 100, 'linear_tree': True}
+ train_data = lgb.Dataset(X_train, label=y_train, params=params)
+ valid_data = train_data.create_valid(X_test, label=y_test, params=params)
+ train_data.construct()
+ valid_data.construct()
+
+ def test_save_and_load_linear(self):
+ X_train, X_test, y_train, y_test = train_test_split(*load_breast_cancer(return_X_y=True), test_size=0.1,
+ random_state=2)
+ X_train = np.concatenate([np.ones((X_train.shape[0], 1)), X_train], 1)
+ X_train[:X_train.shape[0] // 2, 0] = 0
+ y_train[:X_train.shape[0] // 2] = 1
+ params = {'linear_tree': True}
+ train_data = lgb.Dataset(X_train, label=y_train, params=params)
+ est = lgb.train(params, train_data, num_boost_round=10, categorical_feature=[0])
+ pred1 = est.predict(X_train)
+ train_data.save_binary('temp_dataset.bin')
+ train_data_2 = lgb.Dataset('temp_dataset.bin')
+ est = lgb.train(params, train_data_2, num_boost_round=10)
+ pred2 = est.predict(X_train)
+ np.testing.assert_allclose(pred1, pred2)
+ est.save_model('temp_model.txt')
+ est2 = lgb.Booster(model_file='temp_model.txt')
+ pred3 = est2.predict(X_train)
+ np.testing.assert_allclose(pred2, pred3)
+
def test_subset_group(self):
X_train, y_train = load_svmlight_file(os.path.join(os.path.dirname(os.path.realpath(__file__)),
'../../examples/lambdarank/rank.train'))
diff --git a/tests/python_package_test/test_consistency.py b/tests/python_package_test/test_consistency.py
index 782dac4368e3..ce6223d13935 100644
--- a/tests/python_package_test/test_consistency.py
+++ b/tests/python_package_test/test_consistency.py
@@ -78,6 +78,18 @@ def test_binary(self):
fd.train_predict_check(lgb_train, X_test, X_test_fn, sk_pred)
fd.file_load_check(lgb_train, '.train')
+ def test_binary_linear(self):
+ fd = FileLoader('../../examples/binary_classification', 'binary', 'train_linear.conf')
+ X_train, y_train, _ = fd.load_dataset('.train')
+ X_test, _, X_test_fn = fd.load_dataset('.test')
+ weight_train = fd.load_field('.train.weight')
+ lgb_train = lgb.Dataset(X_train, y_train, params=fd.params, weight=weight_train)
+ gbm = lgb.LGBMClassifier(**fd.params)
+ gbm.fit(X_train, y_train, sample_weight=weight_train)
+ sk_pred = gbm.predict_proba(X_test)[:, 1]
+ fd.train_predict_check(lgb_train, X_test, X_test_fn, sk_pred)
+ fd.file_load_check(lgb_train, '.train')
+
def test_multiclass(self):
fd = FileLoader('../../examples/multiclass_classification', 'multiclass')
X_train, y_train, _ = fd.load_dataset('.train')
diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py
index 84f5a8cc1071..faed1fce28b4 100644
--- a/tests/python_package_test/test_engine.py
+++ b/tests/python_package_test/test_engine.py
@@ -2420,6 +2420,86 @@ def test_interaction_constraints(self):
[1] + list(range(2, num_features))]),
train_data, num_boost_round=10)
+ def test_linear(self):
+ # check that setting linear_tree=True fits better than the default when the data has a linear relationship
+ np.random.seed(0)
+ x = np.arange(0, 100, 0.1)
+ y = 2 * x + np.random.normal(0, 0.1, len(x))
+ lgb_train = lgb.Dataset(x[:, np.newaxis], label=y)
+ params = {'verbose': -1,
+ 'metric': 'mse',
+ 'seed': 0,
+ 'num_leaves': 2}
+ est = lgb.train(params, lgb_train, num_boost_round=10)
+ pred1 = est.predict(x[:, np.newaxis])
+ lgb_train = lgb.Dataset(x[:, np.newaxis], label=y)
+ res = {}
+ est = lgb.train(dict(params, linear_tree=True), lgb_train, num_boost_round=10, evals_result=res,
+ valid_sets=[lgb_train], valid_names=['train'])
+ pred2 = est.predict(x[:, np.newaxis])
+ np.testing.assert_allclose(res['train']['l2'][-1], mean_squared_error(y, pred2), atol=10**(-1))
+ self.assertLess(mean_squared_error(y, pred2), mean_squared_error(y, pred1))
+ # test again with nans in data
+ x[:10] = np.nan
+ lgb_train = lgb.Dataset(x[:, np.newaxis], label=y)
+ est = lgb.train(params, lgb_train, num_boost_round=10)
+ pred1 = est.predict(x[:, np.newaxis])
+ lgb_train = lgb.Dataset(x[:, np.newaxis], label=y)
+ res = {}
+ est = lgb.train(dict(params, linear_tree=True), lgb_train, num_boost_round=10, evals_result=res,
+ valid_sets=[lgb_train], valid_names=['train'])
+ pred2 = est.predict(x[:, np.newaxis])
+ np.testing.assert_allclose(res['train']['l2'][-1], mean_squared_error(y, pred2), atol=10**(-1))
+ self.assertLess(mean_squared_error(y, pred2), mean_squared_error(y, pred1))
+ # test again with bagging
+ res = {}
+ est = lgb.train(dict(params, linear_tree=True, subsample=0.8, bagging_freq=1), lgb_train,
+ num_boost_round=10, evals_result=res, valid_sets=[lgb_train], valid_names=['train'])
+ pred = est.predict(x[:, np.newaxis])
+ np.testing.assert_allclose(res['train']['l2'][-1], mean_squared_error(y, pred), atol=10**(-1))
+ # test with a feature that has only one non-nan value
+ x = np.concatenate([np.ones([x.shape[0], 1]), x[:, np.newaxis]], 1)
+ x[500:, 1] = np.nan
+ y[500:] += 10
+ lgb_train = lgb.Dataset(x, label=y)
+ res = {}
+ est = lgb.train(dict(params, linear_tree=True, subsample=0.8, bagging_freq=1), lgb_train,
+ num_boost_round=10, evals_result=res, valid_sets=[lgb_train], valid_names=['train'])
+ pred = est.predict(x)
+ np.testing.assert_allclose(res['train']['l2'][-1], mean_squared_error(y, pred), atol=10**(-1))
+ # test with a categorical feature
+ x[:250, 0] = 0
+ y[:250] += 10
+ lgb_train = lgb.Dataset(x, label=y)
+ est = lgb.train(dict(params, linear_tree=True, subsample=0.8, bagging_freq=1), lgb_train,
+ num_boost_round=10, categorical_feature=[0])
+ # test refit: same results on same data
+ est2 = est.refit(x, label=y)
+ p1 = est.predict(x)
+ p2 = est2.predict(x)
+ self.assertLess(np.mean(np.abs(p1 - p2)), 2)
+ # test refit with save and load
+ est.save_model('temp_model.txt')
+ est2 = lgb.Booster(model_file='temp_model.txt')
+ est2 = est2.refit(x, label=y)
+ p1 = est.predict(x)
+ p2 = est2.predict(x)
+ self.assertLess(np.mean(np.abs(p1 - p2)), 2)
+ # test refit: different results training on different data
+ est2 = est.refit(x[:100, :], label=y[:100])
+ p3 = est2.predict(x)
+ self.assertGreater(np.mean(np.abs(p3 - p1)), np.mean(np.abs(p2 - p1)))
+ # test when num_leaves - 1 < num_features and when num_leaves - 1 > num_features
+ X_train, X_test, y_train, y_test = train_test_split(*load_breast_cancer(return_X_y=True), test_size=0.1, random_state=2)
+ params = {'linear_tree': True,
+ 'verbose': -1,
+ 'metric': 'mse',
+ 'seed': 0}
+ train_data = lgb.Dataset(X_train, label=y_train, params=dict(params, num_leaves=2))
+ est = lgb.train(params, train_data, num_boost_round=10, categorical_feature=[0])
+ train_data = lgb.Dataset(X_train, label=y_train, params=dict(params, num_leaves=60))
+ est = lgb.train(params, train_data, num_boost_round=10, categorical_feature=[0])
+
def test_predict_with_start_iteration(self):
def inner_test(X, y, params, early_stopping_rounds):
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
diff --git a/windows/LightGBM.vcxproj b/windows/LightGBM.vcxproj
index 6b37ca53b728..5c01a8cdb932 100644
--- a/windows/LightGBM.vcxproj
+++ b/windows/LightGBM.vcxproj
@@ -112,6 +112,7 @@
Disabled
MultiThreadedDebugDLL
true
+ <AdditionalIncludeDirectories>$(ProjectDir)\..\eigen;%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories>
@@ -134,6 +135,7 @@
Disabled
MultiThreadedDebugDLL
true
+ <AdditionalIncludeDirectories>$(ProjectDir)\..\eigen;%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories>
@@ -153,6 +155,7 @@
Disabled
MultiThreadedDebugDLL
true
+ <AdditionalIncludeDirectories>$(ProjectDir)\..\eigen;%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories>
@@ -175,6 +178,7 @@
true
MultiThreadedDLL
true
+ <AdditionalIncludeDirectories>$(ProjectDir)\..\eigen;%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories>
@@ -201,6 +205,7 @@
true
true
true
+ <AdditionalIncludeDirectories>$(ProjectDir)\..\eigen;%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories>
@@ -221,6 +226,7 @@
true
true
true
+ <AdditionalIncludeDirectories>$(ProjectDir)\..\eigen;%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories>
@@ -287,6 +293,7 @@
+ <ClInclude Include="..\src\treelearner\linear_tree_learner.h" />
@@ -321,6 +328,7 @@
+ <ClCompile Include="..\src\treelearner\linear_tree_learner.cpp" />
@@ -328,4 +336,4 @@
-
+
\ No newline at end of file