diff --git a/.ci/test.sh b/.ci/test.sh index 4cd181790ad7..776bf7aa0287 100755 --- a/.ci/test.sh +++ b/.ci/test.sh @@ -52,8 +52,8 @@ if [[ $TRAVIS == "true" ]] && [[ $TASK == "lint" ]]; then "r-lintr>=2.0" pip install --user cpplint echo "Linting Python code" - pycodestyle --ignore=E501,W503 --exclude=./compute,./.nuget,./external_libs . || exit -1 - pydocstyle --convention=numpy --add-ignore=D105 --match-dir="^(?!^compute|external_libs|test|example).*" --match="(?!^test_|setup).*\.py" . || exit -1 + pycodestyle --ignore=E501,W503 --exclude=./compute,./eigen,./.nuget,./external_libs . || exit -1 + pydocstyle --convention=numpy --add-ignore=D105 --match-dir="^(?!^compute|^eigen|external_libs|test|example).*" --match="(?!^test_|setup).*\.py" . || exit -1 echo "Linting R code" Rscript ${BUILD_DIRECTORY}/.ci/lint_r_code.R ${BUILD_DIRECTORY} || exit -1 echo "Linting C++ code" diff --git a/.gitmodules b/.gitmodules index 8f1772cd19fa..ccab67d4a1f5 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,6 +1,9 @@ [submodule "include/boost/compute"] path = compute url = https://github.com/boostorg/compute +[submodule "eigen"] + path = eigen + url = https://gitlab.com/libeigen/eigen.git [submodule "external_libs/fmt"] path = external_libs/fmt url = https://github.com/fmtlib/fmt.git diff --git a/CMakeLists.txt b/CMakeLists.txt index c8559c966ab1..de0dd6573199 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -90,6 +90,9 @@ if(USE_SWIG) endif() endif(USE_SWIG) +SET(EIGEN_DIR "${PROJECT_SOURCE_DIR}/eigen") +include_directories(${EIGEN_DIR}) + if(__BUILD_FOR_R) list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake/modules") find_package(LibR REQUIRED) diff --git a/R-package/src/Makevars.in b/R-package/src/Makevars.in index 0c580ebe36bf..2490ba0757df 100644 --- a/R-package/src/Makevars.in +++ b/R-package/src/Makevars.in @@ -48,6 +48,7 @@ OBJECTS = \ treelearner/data_parallel_tree_learner.o \ treelearner/feature_parallel_tree_learner.o \ treelearner/gpu_tree_learner.o \ + treelearner/linear_tree_learner.o \ treelearner/serial_tree_learner.o \ treelearner/tree_learner.o \ treelearner/voting_parallel_tree_learner.o \ diff --git a/R-package/src/Makevars.win.in b/R-package/src/Makevars.win.in index 952508e16a7d..0fb2de926905 100644 --- a/R-package/src/Makevars.win.in +++ b/R-package/src/Makevars.win.in @@ -49,6 +49,7 @@ OBJECTS = \ treelearner/data_parallel_tree_learner.o \ treelearner/feature_parallel_tree_learner.o \ treelearner/gpu_tree_learner.o \ + treelearner/linear_tree_learner.o \ treelearner/serial_tree_learner.o \ treelearner/tree_learner.o \ treelearner/voting_parallel_tree_learner.o \ diff --git a/docs/Parameters.rst b/docs/Parameters.rst index 22e98217d614..1287c1354331 100644 --- a/docs/Parameters.rst +++ b/docs/Parameters.rst @@ -117,6 +117,26 @@ Core Parameters - **Note**: internally, LightGBM uses ``gbdt`` mode for the first ``1 / learning_rate`` iterations +- ``linear_tree`` :raw-html:`🔗︎`, default = ``false``, type = bool + + - fit piecewise linear gradient boosting tree, only works with cpu and serial tree learner + + - tree splits are chosen in the usual way, but the model at each leaf is linear instead of constant + + - the linear model at each leaf includes all the numerical features in that leaf's branch + + - categorical features are used for splits as normal but are not used in the linear models + + - missing values must be encoded as ``np.nan`` (Python) or ``NA`` (cli), not ``0`` + + - it is recommended to rescale data before training so that features have similar mean and standard 
deviation + + - not yet supported in R-package + + - ``regression_l1`` objective is not supported with linear tree boosting + + - setting ``linear_tree = True`` significantly increases the memory use of LightGBM + - ``data`` :raw-html:`🔗︎`, default = ``""``, type = string, aliases: ``train``, ``train_data``, ``train_data_file``, ``data_filename`` - path of training data, LightGBM will train from this data @@ -384,6 +404,10 @@ Learning Control Parameters - L2 regularization +- ``linear_lambda`` :raw-html:`🔗︎`, default = ``0.0``, type = double, constraints: ``linear_lambda >= 0.0`` + + - Linear tree regularisation, the parameter `lambda` in Eq 3 of + - ``min_gain_to_split`` :raw-html:`🔗︎`, default = ``0.0``, type = double, aliases: ``min_split_gain``, constraints: ``min_gain_to_split >= 0.0`` - the minimal gain to perform split diff --git a/docs/conf.py b/docs/conf.py index aa1cb2dc3610..e30566c6fc46 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -202,6 +202,7 @@ def generate_doxygen_xml(app): "SKIP_FUNCTION_MACROS=NO", "SORT_BRIEF_DOCS=YES", "WARN_AS_ERROR=YES", + "EXCLUDE_PATTERNS=*/eigen/*" ] doxygen_input = '\n'.join(doxygen_args) doxygen_input = bytes(doxygen_input, "utf-8") diff --git a/eigen b/eigen new file mode 160000 index 000000000000..8ba1b0f41a79 --- /dev/null +++ b/eigen @@ -0,0 +1 @@ +Subproject commit 8ba1b0f41a7950dc3e1d4ed75859e36c73311235 diff --git a/examples/binary_classification/train_linear.conf b/examples/binary_classification/train_linear.conf new file mode 100644 index 000000000000..616d5fc39e35 --- /dev/null +++ b/examples/binary_classification/train_linear.conf @@ -0,0 +1,113 @@ +# task type, support train and predict +task = train + +# boosting type, support gbdt for now, alias: boosting, boost +boosting_type = gbdt + +# application type, support following application +# regression , regression task +# binary , binary classification task +# lambdarank , lambdarank task +# alias: application, app +objective = binary + +linear_tree = true + +# eval metrics, support multi metric, delimite by ',' , support following metrics +# l1 +# l2 , default metric for regression +# ndcg , default metric for lambdarank +# auc +# binary_logloss , default metric for binary +# binary_error +metric = binary_logloss,auc + +# frequence for metric output +metric_freq = 1 + +# true if need output metric for training data, alias: tranining_metric, train_metric +is_training_metric = true + +# number of bins for feature bucket, 255 is a recommend setting, it can save memories, and also has good accuracy. +max_bin = 255 + +# training data +# if exsting weight file, should name to "binary.train.weight" +# alias: train_data, train +data = binary.train + +# validation data, support multi validation data, separated by ',' +# if exsting weight file, should name to "binary.test.weight" +# alias: valid, test, test_data, +valid_data = binary.test + +# number of trees(iterations), alias: num_tree, num_iteration, num_iterations, num_round, num_rounds +num_trees = 100 + +# shrinkage rate , alias: shrinkage_rate +learning_rate = 0.1 + +# number of leaves for one tree, alias: num_leaf +num_leaves = 63 + +# type of tree learner, support following types: +# serial , single machine version +# feature , use feature parallel to train +# data , use data parallel to train +# voting , use voting based parallel to train +# alias: tree +tree_learner = serial + +# number of threads for multi-threading. One thread will use one CPU, defalut is setted to #cpu. 
+# num_threads = 8 + +# feature sub-sample, will random select 80% feature to train on each iteration +# alias: sub_feature +feature_fraction = 0.8 + +# Support bagging (data sub-sample), will perform bagging every 5 iterations +bagging_freq = 5 + +# Bagging farction, will random select 80% data on bagging +# alias: sub_row +bagging_fraction = 0.8 + +# minimal number data for one leaf, use this to deal with over-fit +# alias : min_data_per_leaf, min_data +min_data_in_leaf = 50 + +# minimal sum hessians for one leaf, use this to deal with over-fit +min_sum_hessian_in_leaf = 5.0 + +# save memory and faster speed for sparse feature, alias: is_sparse +is_enable_sparse = true + +# when data is bigger than memory size, set this to true. otherwise set false will have faster speed +# alias: two_round_loading, two_round +use_two_round_loading = false + +# true if need to save data to binary file and application will auto load data from binary file next time +# alias: is_save_binary, save_binary +is_save_binary_file = false + +# output model file +output_model = LightGBM_model.txt + +# support continuous train from trained gbdt model +# input_model= trained_model.txt + +# output prediction file for predict task +# output_result= prediction.txt + + +# number of machines in parallel training, alias: num_machine +num_machines = 1 + +# local listening port in parallel training, alias: local_port +local_listen_port = 12400 + +# machines list file for parallel training, alias: mlist +machine_list_file = mlist.txt + +# force splits +# forced_splits = forced_splits.json diff --git a/include/LightGBM/boosting.h b/include/LightGBM/boosting.h index ca198d37671e..ddbcdbc18e44 100644 --- a/include/LightGBM/boosting.h +++ b/include/LightGBM/boosting.h @@ -312,6 +312,8 @@ class LIGHTGBM_EXPORT Boosting { * \return The boosting object */ static Boosting* CreateBoosting(const std::string& type, const char* filename); + + virtual bool IsLinear() const { return false; } }; class GBDTBase : public Boosting { diff --git a/include/LightGBM/c_api.h b/include/LightGBM/c_api.h index 521ea821eccd..c16143d82160 100644 --- a/include/LightGBM/c_api.h +++ b/include/LightGBM/c_api.h @@ -389,6 +389,14 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetGetNumData(DatasetHandle handle, LIGHTGBM_C_EXPORT int LGBM_DatasetGetNumFeature(DatasetHandle handle, int* out); +/*! +* \brief Get boolean representing whether booster is fitting linear trees. +* \param handle Handle of dataset +* \param[out] out The address to hold linear indicator +* \return 0 when succeed, -1 when failure happens +*/ +LIGHTGBM_C_EXPORT int LGBM_BoosterGetLinear(BoosterHandle handle, bool* out); + /*! * \brief Add features from ``source`` to ``target``. 
* \param target The handle of the dataset to add features to diff --git a/include/LightGBM/config.h b/include/LightGBM/config.h index 8c017ad26926..43f8edd99bf8 100644 --- a/include/LightGBM/config.h +++ b/include/LightGBM/config.h @@ -148,6 +148,17 @@ struct Config { // descl2 = **Note**: internally, LightGBM uses ``gbdt`` mode for the first ``1 / learning_rate`` iterations std::string boosting = "gbdt"; + // desc = fit piecewise linear gradient boosting tree, only works with cpu and serial tree learner + // descl2 = tree splits are chosen in the usual way, but the model at each leaf is linear instead of constant + // descl2 = the linear model at each leaf includes all the numerical features in that leaf's branch + // descl2 = categorical features are used for splits as normal but are not used in the linear models + // descl2 = missing values must be encoded as ``np.nan`` (Python) or ``NA`` (cli), not ``0`` + // descl2 = it is recommended to rescale data before training so that features have similar mean and standard deviation + // descl2 = not yet supported in R-package + // descl2 = ``regression_l1`` objective is not supported with linear tree boosting + // descl2 = setting ``linear_tree = True`` significantly increases the memory use of LightGBM + bool linear_tree = false; + // alias = train, train_data, train_data_file, data_filename // desc = path of training data, LightGBM will train from this data // desc = **Note**: can be used only in CLI version @@ -366,6 +377,10 @@ struct Config { // desc = L2 regularization double lambda_l2 = 0.0; + // check = >=0.0 + // desc = Linear tree regularisation, the parameter `lambda` in Eq 3 of + double linear_lambda = 0.0; + // alias = min_split_gain // check = >=0.0 // desc = the minimal gain to perform split diff --git a/include/LightGBM/dataset.h b/include/LightGBM/dataset.h index 44c4196544b6..b1ee7a40dd40 100644 --- a/include/LightGBM/dataset.h +++ b/include/LightGBM/dataset.h @@ -337,6 +337,12 @@ class Dataset { const int group = feature2group_[feature_idx]; const int sub_feature = feature2subfeature_[feature_idx]; feature_groups_[group]->PushData(tid, sub_feature, row_idx, feature_values[i]); + if (has_raw_) { + int feat_ind = numeric_feature_map_[feature_idx]; + if (feat_ind >= 0) { + raw_data_[feat_ind][row_idx] = feature_values[i]; + } + } } } } @@ -352,13 +358,25 @@ class Dataset { const int group = feature2group_[feature_idx]; const int sub_feature = feature2subfeature_[feature_idx]; feature_groups_[group]->PushData(tid, sub_feature, row_idx, inner_data.second); + if (has_raw_) { + int feat_ind = numeric_feature_map_[feature_idx]; + if (feat_ind >= 0) { + raw_data_[feat_ind][row_idx] = inner_data.second; + } + } } } FinishOneRow(tid, row_idx, is_feature_added); } - inline void PushOneData(int tid, data_size_t row_idx, int group, int sub_feature, double value) { + inline void PushOneData(int tid, data_size_t row_idx, int group, int feature_idx, int sub_feature, double value) { feature_groups_[group]->PushData(tid, sub_feature, row_idx, value); + if (has_raw_) { + int feat_ind = numeric_feature_map_[feature_idx]; + if (feat_ind >= 0) { + raw_data_[feat_ind][row_idx] = value; + } + } } inline int RealFeatureIndex(int fidx) const { @@ -569,6 +587,9 @@ class Dataset { /*! \brief Get Number of used features */ inline int num_features() const { return num_features_; } + /*! \brief Get number of numeric features */ + inline int num_numeric_features() const { return num_numeric_features_; } + /*! 
\brief Get Number of feature groups */ inline int num_feature_groups() const { return num_groups_;} @@ -632,6 +653,31 @@ class Dataset { void AddFeaturesFrom(Dataset* other); + /*! \brief Get has_raw_ */ + inline bool has_raw() const { return has_raw_; } + + /*! \brief Set has_raw_ */ + inline void SetHasRaw(bool has_raw) { has_raw_ = has_raw; } + + /*! \brief Resize raw_data_ */ + inline void ResizeRaw(int num_rows) { + if (static_cast(raw_data_.size()) > num_numeric_features_) { + raw_data_.resize(num_numeric_features_); + } + for (size_t i = 0; i < raw_data_.size(); ++i) { + raw_data_[i].resize(num_rows); + } + int curr_size = static_cast(raw_data_.size()); + for (int i = curr_size; i < num_numeric_features_; ++i) { + raw_data_.push_back(std::vector(num_rows, 0)); + } + } + + /*! \brief Get pointer to raw_data_ feature */ + inline const float* raw_index(int feat_ind) const { + return raw_data_[numeric_feature_map_[feat_ind]].data(); + } + private: std::string data_filename_; /*! \brief Store used features */ @@ -668,6 +714,11 @@ class Dataset { bool use_missing_; bool zero_as_missing_; std::vector feature_need_push_zeros_; + std::vector> raw_data_; + bool has_raw_; + /*! map feature (inner index) to its index in the list of numeric (non-categorical) features */ + std::vector numeric_feature_map_; + int num_numeric_features_; }; } // namespace LightGBM diff --git a/include/LightGBM/dataset_loader.h b/include/LightGBM/dataset_loader.h index fde3d1c00e16..e72dd4910804 100644 --- a/include/LightGBM/dataset_loader.h +++ b/include/LightGBM/dataset_loader.h @@ -82,6 +82,8 @@ class DatasetLoader { std::vector feature_names_; /*! \brief Mapper from real feature index to used index*/ std::unordered_set categorical_features_; + /*! \brief Whether to store raw feature values */ + bool store_raw_; }; } // namespace LightGBM diff --git a/include/LightGBM/tree.h b/include/LightGBM/tree.h index fe46dc4f313e..d5d9b5282059 100644 --- a/include/LightGBM/tree.h +++ b/include/LightGBM/tree.h @@ -28,8 +28,9 @@ class Tree { * \brief Constructor * \param max_leaves The number of max leaves * \param track_branch_features Whether to keep track of ancestors of leaf nodes + * \param is_linear Whether the tree has linear models at each leaf */ - explicit Tree(int max_leaves, bool track_branch_features); + explicit Tree(int max_leaves, bool track_branch_features, bool is_linear); /*! * \brief Constructor, from a string @@ -148,9 +149,12 @@ class Tree { /*! \brief Get parent of specific leaf*/ inline int leaf_parent(int leaf_idx) const {return leaf_parent_[leaf_idx]; } - /*! \brief Get feature of specific split*/ + /*! \brief Get feature of specific split (original feature index)*/ inline int split_feature(int split_idx) const { return split_feature_[split_idx]; } + /*! \brief Get feature of specific split*/ + inline int split_feature_inner(int split_idx) const { return split_feature_inner_[split_idx]; } + /*! 
\brief Get features on leaf's branch*/ inline std::vector branch_features(int leaf) const { return branch_features_[leaf]; } @@ -168,10 +172,6 @@ class Tree { inline int right_child(int node_idx) const { return right_child_[node_idx]; } - inline int split_feature_inner(int node_idx) const { - return split_feature_inner_[node_idx]; - } - inline uint32_t threshold_in_bin(int node_idx) const { return threshold_in_bin_[node_idx]; } @@ -189,9 +189,21 @@ class Tree { for (int i = 0; i < num_leaves_ - 1; ++i) { leaf_value_[i] = MaybeRoundToZero(leaf_value_[i] * rate); internal_value_[i] = MaybeRoundToZero(internal_value_[i] * rate); + if (is_linear_) { + leaf_const_[i] = MaybeRoundToZero(leaf_const_[i] * rate); + for (size_t j = 0; j < leaf_coeff_[i].size(); ++j) { + leaf_coeff_[i][j] = MaybeRoundToZero(leaf_coeff_[i][j] * rate); + } + } } leaf_value_[num_leaves_ - 1] = MaybeRoundToZero(leaf_value_[num_leaves_ - 1] * rate); + if (is_linear_) { + leaf_const_[num_leaves_ - 1] = MaybeRoundToZero(leaf_const_[num_leaves_ - 1] * rate); + for (size_t j = 0; j < leaf_coeff_[num_leaves_ - 1].size(); ++j) { + leaf_coeff_[num_leaves_ - 1][j] = MaybeRoundToZero(leaf_coeff_[num_leaves_ - 1][j] * rate); + } + } shrinkage_ *= rate; } @@ -205,6 +217,13 @@ class Tree { } leaf_value_[num_leaves_ - 1] = MaybeRoundToZero(leaf_value_[num_leaves_ - 1] + val); + if (is_linear_) { +#pragma omp parallel for schedule(static, 1024) if (num_leaves_ >= 2048) + for (int i = 0; i < num_leaves_ - 1; ++i) { + leaf_const_[i] = MaybeRoundToZero(leaf_const_[i] + val); + } + leaf_const_[num_leaves_ - 1] = MaybeRoundToZero(leaf_const_[num_leaves_ - 1] + val); + } // force to 1.0 shrinkage_ = 1.0f; } @@ -213,6 +232,9 @@ class Tree { num_leaves_ = 1; shrinkage_ = 1.0f; leaf_value_[0] = val; + if (is_linear_) { + leaf_const_[0] = val; + } } /*! \brief Serialize this object to string*/ @@ -257,6 +279,47 @@ class Tree { int NextLeafId() const { return num_leaves_; } + /*! \brief Get the linear model constant term (bias) of one leaf */ + inline double LeafConst(int leaf) const { return leaf_const_[leaf]; } + + /*! \brief Get the linear model coefficients of one leaf */ + inline std::vector LeafCoeffs(int leaf) const { return leaf_coeff_[leaf]; } + + /*! \brief Get the linear model features of one leaf */ + inline std::vector LeafFeaturesInner(int leaf) const {return leaf_features_inner_[leaf]; } + + /*! \brief Get the linear model features of one leaf */ + inline std::vector LeafFeatures(int leaf) const {return leaf_features_[leaf]; } + + /*! \brief Set the linear model coefficients on one leaf */ + inline void SetLeafCoeffs(int leaf, const std::vector& output) { + leaf_coeff_[leaf].resize(output.size()); + for (size_t i = 0; i < output.size(); ++i) { + leaf_coeff_[leaf][i] = MaybeRoundToZero(output[i]); + } + } + + /*! \brief Set the linear model constant term (bias) on one leaf */ + inline void SetLeafConst(int leaf, double output) { + leaf_const_[leaf] = MaybeRoundToZero(output); + } + + /*! \brief Set the linear model features on one leaf */ + inline void SetLeafFeaturesInner(int leaf, const std::vector& features) { + leaf_features_inner_[leaf] = features; + } + + /*! 
\brief Set the linear model features on one leaf */ + inline void SetLeafFeatures(int leaf, const std::vector& features) { + leaf_features_[leaf] = features; + } + + inline bool is_linear() const { return is_linear_; } + + inline void SetIsLinear(bool is_linear) { + is_linear_ = is_linear; + } + private: std::string NumericalDecisionIfElse(int node) const; @@ -454,6 +517,16 @@ class Tree { std::vector> branch_features_; double shrinkage_; int max_depth_; + /*! \brief Tree has linear model at each leaf */ + bool is_linear_; + /*! \brief coefficients of linear models on leaves */ + std::vector> leaf_coeff_; + /*! \brief constant term (bias) of linear models on leaves */ + std::vector leaf_const_; + /* \brief features used in leaf linear models; indexing is relative to num_total_features_ */ + std::vector> leaf_features_; + /* \brief features used in leaf linear models; indexing is relative to used_features_ */ + std::vector> leaf_features_inner_; }; inline void Tree::Split(int leaf, int feature, int real_feature, @@ -501,20 +574,65 @@ inline void Tree::Split(int leaf, int feature, int real_feature, } inline double Tree::Predict(const double* feature_values) const { - if (num_leaves_ > 1) { - int leaf = GetLeaf(feature_values); - return LeafOutput(leaf); + if (is_linear_) { + int leaf = GetLeaf(feature_values); + double output = leaf_const_[leaf]; + bool nan_found = false; + for (size_t i = 0; i < leaf_features_[leaf].size(); ++i) { + int feat_raw = leaf_features_[leaf][i]; + double feat_val = feature_values[feat_raw]; + if (std::isnan(feat_val)) { + nan_found = true; + break; + } else { + output += leaf_coeff_[leaf][i] * feat_val; + } + } + if (nan_found) { + return LeafOutput(leaf); + } else { + return output; + } } else { - return leaf_value_[0]; + if (num_leaves_ > 1) { + int leaf = GetLeaf(feature_values); + return LeafOutput(leaf); + } else { + return leaf_value_[0]; + } } } inline double Tree::PredictByMap(const std::unordered_map& feature_values) const { - if (num_leaves_ > 1) { + if (is_linear_) { int leaf = GetLeafByMap(feature_values); - return LeafOutput(leaf); + double output = leaf_const_[leaf]; + bool nan_found = false; + for (size_t i = 0; i < leaf_features_[leaf].size(); ++i) { + int feat = leaf_features_[leaf][i]; + auto val_it = feature_values.find(feat); + if (val_it != feature_values.end()) { + double feat_val = val_it->second; + if (std::isnan(feat_val)) { + nan_found = true; + break; + } else { + output += leaf_coeff_[leaf][i] * feat_val; + } + } + } + if (nan_found) { + return LeafOutput(leaf); + } else { + return output; + } } else { - return leaf_value_[0]; + if (num_leaves_ > 1) { + int leaf = GetLeafByMap(feature_values); + return LeafOutput(leaf); + } else { + return leaf_value_[0]; + } } } diff --git a/include/LightGBM/tree_learner.h b/include/LightGBM/tree_learner.h index e0fb34890579..55343ad714ee 100644 --- a/include/LightGBM/tree_learner.h +++ b/include/LightGBM/tree_learner.h @@ -36,6 +36,9 @@ class TreeLearner { */ virtual void Init(const Dataset* train_data, bool is_constant_hessian) = 0; + /*! 
Initialise some temporary storage, only needed for the linear tree; needs to be a method of TreeLearner since we call it in GBDT::RefitTree */ + virtual void InitLinear(const Dataset* /*train_data*/, const int /*max_leaves*/) {} + virtual void ResetIsConstantHessian(bool is_constant_hessian) = 0; virtual void ResetTrainingData(const Dataset* train_data, @@ -53,9 +56,10 @@ class TreeLearner { * \brief training tree model on dataset * \param gradients The first order gradients * \param hessians The second order gradients + * \param is_first_tree If linear tree learning is enabled, first tree needs to be handled differently * \return A trained tree */ - virtual Tree* Train(const score_t* gradients, const score_t* hessians) = 0; + virtual Tree* Train(const score_t* gradients, const score_t* hessians, bool is_first_tree) = 0; /*! * \brief use an existing tree to fit the new gradients and hessians. @@ -63,7 +67,7 @@ class TreeLearner { virtual Tree* FitByExistingTree(const Tree* old_tree, const score_t* gradients, const score_t* hessians) const = 0; virtual Tree* FitByExistingTree(const Tree* old_tree, const std::vector& leaf_pred, - const score_t* gradients, const score_t* hessians) = 0; + const score_t* gradients, const score_t* hessians) const = 0; /*! * \brief Set bagging data @@ -94,11 +98,12 @@ class TreeLearner { * \brief Create object of tree learner * \param learner_type Type of tree learner * \param device_type Type of tree learner + * \param booster_type Type of boosting * \param config config of tree */ static TreeLearner* CreateTreeLearner(const std::string& learner_type, - const std::string& device_type, - const Config* config); + const std::string& device_type, + const Config* config); }; } // namespace LightGBM diff --git a/include/LightGBM/utils/openmp_wrapper.h b/include/LightGBM/utils/openmp_wrapper.h index 71574ff894b6..1278d172a25d 100644 --- a/include/LightGBM/utils/openmp_wrapper.h +++ b/include/LightGBM/utils/openmp_wrapper.h @@ -78,6 +78,7 @@ class ThreadExceptionHelper { All #pragma omp should be ignored by the compiler **/ inline void omp_set_num_threads(int) {} inline int omp_get_num_threads() {return 1;} + inline int omp_get_max_threads() {return 1;} inline int omp_get_thread_num() {return 0;} inline int OMP_NUM_THREADS() { return 1; } #ifdef __cplusplus diff --git a/python-package/MANIFEST.in b/python-package/MANIFEST.in index 62bc7f64fcfb..cc0ab0626f91 100644 --- a/python-package/MANIFEST.in +++ b/python-package/MANIFEST.in @@ -10,6 +10,7 @@ include compile/compute/CMakeLists.txt recursive-include compile/compute/cmake * recursive-include compile/compute/include * recursive-include compile/compute/meta * +recursive-include compile/eigen * include compile/external_libs/fast_double_parser/CMakeLists.txt include compile/external_libs/fast_double_parser/LICENSE include compile/external_libs/fast_double_parser/LICENSE.BSL diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index 4e50fe0f5b81..a48350d9e280 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -1011,6 +1011,7 @@ def get_params(self): "ignore_column", "is_enable_sparse", "label_column", + "linear_tree", "max_bin", "max_bin_by_feature", "min_data_in_bin", @@ -2999,8 +3000,13 @@ def refit(self, data, label, decay_rate=0.9, **kwargs): predictor = self._to_predictor(copy.deepcopy(kwargs)) leaf_preds = predictor.predict(data, -1, pred_leaf=True) nrow, ncol = leaf_preds.shape - train_set = Dataset(data, label, silent=True) + out_is_linear = 
ctypes.c_bool(False) + _safe_call(_LIB.LGBM_BoosterGetLinear( + self.handle, + ctypes.byref(out_is_linear))) new_params = copy.deepcopy(self.params) + new_params["linear_tree"] = out_is_linear.value + train_set = Dataset(data, label, silent=True, params=new_params) new_params['refit_decay_rate'] = decay_rate new_booster = Booster(new_params, train_set) # Copy models diff --git a/python-package/setup.py b/python-package/setup.py index a2e8a0cf3560..32f4d12f92f5 100644 --- a/python-package/setup.py +++ b/python-package/setup.py @@ -65,6 +65,7 @@ def copy_files_helper(folder_name): if not os.path.isfile(os.path.join(CURRENT_DIR, '_IS_SOURCE_PACKAGE.txt')): copy_files_helper('include') copy_files_helper('src') + copy_files_helper('eigen') copy_files_helper('external_libs') if not os.path.exists(os.path.join(CURRENT_DIR, "compile", "windows")): os.makedirs(os.path.join(CURRENT_DIR, "compile", "windows")) diff --git a/src/boosting/gbdt.cpp b/src/boosting/gbdt.cpp index 84b6c0a9bd3a..cf49060f74e1 100644 --- a/src/boosting/gbdt.cpp +++ b/src/boosting/gbdt.cpp @@ -40,6 +40,7 @@ GBDT::GBDT() bagging_runner_(0, bagging_rand_block_) { average_output_ = false; tree_learner_ = nullptr; + linear_tree_ = false; } GBDT::~GBDT() { @@ -88,7 +89,8 @@ void GBDT::Init(const Config* config, const Dataset* train_data, const Objective is_constant_hessian_ = GetIsConstHessian(objective_function); - tree_learner_ = std::unique_ptr(TreeLearner::CreateTreeLearner(config_->tree_learner, config_->device_type, config_.get())); + tree_learner_ = std::unique_ptr(TreeLearner::CreateTreeLearner(config_->tree_learner, config_->device_type, + config_.get())); // init tree learner tree_learner_->Init(train_data_, is_constant_hessian_); @@ -129,6 +131,10 @@ void GBDT::Init(const Config* config, const Dataset* train_data, const Objective class_need_train_[i] = objective_function_->ClassNeedTrain(i); } } + + if (config_->linear_tree) { + linear_tree_ = true; + } } void GBDT::AddValidDataset(const Dataset* valid_data, @@ -282,6 +288,19 @@ void GBDT::RefitTree(const std::vector>& tree_leaf_prediction) CHECK_EQ(static_cast(models_.size()), tree_leaf_prediction[0].size()); int num_iterations = static_cast(models_.size() / num_tree_per_iteration_); std::vector leaf_pred(num_data_); + if (linear_tree_) { + std::vector max_leaves_by_thread = std::vector(OMP_NUM_THREADS(), 0); + #pragma omp parallel for schedule(static) + for (int i = 0; i < static_cast(tree_leaf_prediction.size()); ++i) { + int tid = omp_get_thread_num(); + for (size_t j = 0; j < tree_leaf_prediction[i].size(); ++j) { + max_leaves_by_thread[tid] = std::max(max_leaves_by_thread[tid], tree_leaf_prediction[i][j]); + } + } + int max_leaves = *std::max_element(max_leaves_by_thread.begin(), max_leaves_by_thread.end()); + max_leaves += 1; + tree_learner_->InitLinear(train_data_, max_leaves); + } for (int iter = 0; iter < num_iterations; ++iter) { Boosting(); for (int tree_id = 0; tree_id < num_tree_per_iteration_; ++tree_id) { @@ -365,7 +384,7 @@ bool GBDT::TrainOneIter(const score_t* gradients, const score_t* hessians) { bool should_continue = false; for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) { const size_t offset = static_cast(cur_tree_id) * num_data_; - std::unique_ptr new_tree(new Tree(2, false)); + std::unique_ptr new_tree(new Tree(2, false, false)); if (class_need_train_[cur_tree_id] && train_data_->num_features() > 0) { auto grad = gradients + offset; auto hess = hessians + offset; @@ -378,7 +397,8 @@ bool GBDT::TrainOneIter(const 
score_t* gradients, const score_t* hessians) { grad = gradients_.data() + offset; hess = hessians_.data() + offset; } - new_tree.reset(tree_learner_->Train(grad, hess)); + bool is_first_tree = models_.size() < static_cast(num_tree_per_iteration_); + new_tree.reset(tree_learner_->Train(grad, hess, is_first_tree)); } if (new_tree->num_leaves() > 1) { diff --git a/src/boosting/gbdt.h b/src/boosting/gbdt.h index 0d38385d5f06..998bd5353ce2 100644 --- a/src/boosting/gbdt.h +++ b/src/boosting/gbdt.h @@ -44,6 +44,7 @@ class GBDT : public GBDTBase { */ ~GBDT(); + /*! * \brief Initialization logic * \param gbdt_config Config for boosting @@ -391,6 +392,8 @@ class GBDT : public GBDTBase { */ const char* SubModelName() const override { return "tree"; } + bool IsLinear() const override { return linear_tree_; } + protected: virtual bool GetIsConstHessian(const ObjectiveFunction* objective_function) { if (objective_function != nullptr) { @@ -530,6 +533,7 @@ class GBDT : public GBDTBase { std::vector bagging_rands_; ParallelPartitionRunner bagging_runner_; Json forced_splits_json_; + bool linear_tree_; }; } // namespace LightGBM diff --git a/src/boosting/gbdt_model_text.cpp b/src/boosting/gbdt_model_text.cpp index e5cec8b61db0..8b9a2b8b450b 100644 --- a/src/boosting/gbdt_model_text.cpp +++ b/src/boosting/gbdt_model_text.cpp @@ -581,6 +581,11 @@ bool GBDT::LoadModelFromString(const char* buffer, size_t len) { break; } else if (is_inparameter) { ss << cur_line << "\n"; + if (Common::StartsWith(cur_line, "[linear_tree: ")) { + int is_linear = 0; + Common::Atoi(cur_line.substr(14, 1).c_str(), &is_linear); + linear_tree_ = static_cast(is_linear); + } } } p += line_len; diff --git a/src/boosting/rf.hpp b/src/boosting/rf.hpp index 0768fe10ca31..5a9eb226fef5 100644 --- a/src/boosting/rf.hpp +++ b/src/boosting/rf.hpp @@ -109,7 +109,7 @@ class RF : public GBDT { gradients = gradients_.data(); hessians = hessians_.data(); for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) { - std::unique_ptr new_tree(new Tree(2, false)); + std::unique_ptr new_tree(new Tree(2, false, false)); size_t offset = static_cast(cur_tree_id)* num_data_; if (class_need_train_[cur_tree_id]) { auto grad = gradients + offset; @@ -125,7 +125,7 @@ class RF : public GBDT { hess = tmp_hess_.data(); } - new_tree.reset(tree_learner_->Train(grad, hess)); + new_tree.reset(tree_learner_->Train(grad, hess, false)); } if (new_tree->num_leaves() > 1) { diff --git a/src/c_api.cpp b/src/c_api.cpp index 5cd0eab3d77f..4cf9a5e5072a 100644 --- a/src/c_api.cpp +++ b/src/c_api.cpp @@ -289,6 +289,9 @@ class Booster { "You need to set `feature_pre_filter=false` to dynamically change " "the `min_data_in_leaf`."); } + if (new_param.count("linear_tree") && (new_config.linear_tree != old_config.linear_tree)) { + Log:: Fatal("Cannot change between gbdt_linear boosting and other boosting types after Dataset handle has been constructed."); + } } void ResetConfig(const char* parameters) { @@ -960,6 +963,9 @@ int LGBM_DatasetPushRows(DatasetHandle dataset, API_BEGIN(); auto p_dataset = reinterpret_cast(dataset); auto get_row_fun = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, 1); + if (p_dataset->has_raw()) { + p_dataset->ResizeRaw(p_dataset->num_numeric_features() + nrow); + } OMP_INIT_EX(); #pragma omp parallel for schedule(static) for (int i = 0; i < nrow; ++i) { @@ -990,14 +996,16 @@ int LGBM_DatasetPushRowsByCSR(DatasetHandle dataset, auto p_dataset = reinterpret_cast(dataset); auto get_row_fun = RowFunctionFromCSR(indptr, 
indptr_type, indices, data, data_type, nindptr, nelem); int32_t nrow = static_cast(nindptr - 1); + if (p_dataset->has_raw()) { + p_dataset->ResizeRaw(p_dataset->num_numeric_features() + nrow); + } OMP_INIT_EX(); #pragma omp parallel for schedule(static) for (int i = 0; i < nrow; ++i) { OMP_LOOP_EX_BEGIN(); const int tid = omp_get_thread_num(); auto one_row = get_row_fun(i); - p_dataset->PushOneRow(tid, - static_cast(start_row + i), one_row); + p_dataset->PushOneRow(tid, static_cast(start_row + i), one_row); OMP_LOOP_EX_END(); } OMP_THROW_EX(); @@ -1090,6 +1098,9 @@ int LGBM_DatasetCreateFromMats(int32_t nmat, ret.reset(new Dataset(total_nrow)); ret->CreateValid( reinterpret_cast(reference)); + if (ret->has_raw()) { + ret->ResizeRaw(total_nrow); + } } int32_t start_row = 0; for (int j = 0; j < nmat; ++j) { @@ -1166,6 +1177,9 @@ int LGBM_DatasetCreateFromCSR(const void* indptr, ret.reset(new Dataset(nrow)); ret->CreateValid( reinterpret_cast(reference)); + if (ret->has_raw()) { + ret->ResizeRaw(nrow); + } } OMP_INIT_EX(); #pragma omp parallel for schedule(static) @@ -1234,6 +1248,9 @@ int LGBM_DatasetCreateFromCSRFunc(void* get_row_funptr, ret.reset(new Dataset(nrow)); ret->CreateValid( reinterpret_cast(reference)); + if (ret->has_raw()) { + ret->ResizeRaw(nrow); + } } OMP_INIT_EX(); @@ -1326,12 +1343,12 @@ int LGBM_DatasetCreateFromCSC(const void* col_ptr, row_idx = pair.first; // no more data if (row_idx < 0) { break; } - ret->PushOneData(tid, row_idx, group, sub_feature, pair.second); + ret->PushOneData(tid, row_idx, group, feature_idx, sub_feature, pair.second); } } else { for (int row_idx = 0; row_idx < nrow; ++row_idx) { auto val = col_it.Get(row_idx); - ret->PushOneData(tid, row_idx, group, sub_feature, val); + ret->PushOneData(tid, row_idx, group, feature_idx, sub_feature, val); } } OMP_LOOP_EX_END(); @@ -1600,6 +1617,13 @@ int LGBM_BoosterGetNumClasses(BoosterHandle handle, int* out_len) { API_END(); } +int LGBM_BoosterGetLinear(BoosterHandle handle, bool* out) { + API_BEGIN(); + Booster* ref_booster = reinterpret_cast(handle); + *out = ref_booster->GetBoosting()->IsLinear(); + API_END(); +} + int LGBM_BoosterRefit(BoosterHandle handle, const int32_t* leaf_preds, int32_t nrow, int32_t ncol) { API_BEGIN(); Booster* ref_booster = reinterpret_cast(handle); diff --git a/src/io/config.cpp b/src/io/config.cpp index 45987b6760c5..ce034505ee4b 100644 --- a/src/io/config.cpp +++ b/src/io/config.cpp @@ -335,13 +335,28 @@ void Config::CheckParamConflict() { Log::Warning("Although \"deterministic\" is set, the results ran by GPU may be non-deterministic."); } } - // force gpu_use_dp for CUDA if (device_type == std::string("cuda") && !gpu_use_dp) { Log::Warning("CUDA currently requires double precision calculations."); gpu_use_dp = true; } - + // linear tree learner must be serial type and cpu device + if (linear_tree) { + if (device_type == std::string("gpu")) { + device_type = "cpu"; + Log::Warning("Linear tree learner only works with CPU."); + } + if (tree_learner != std::string("serial")) { + tree_learner = "serial"; + Log::Warning("Linear tree learner must be serial."); + } + if (zero_as_missing) { + Log::Fatal("zero_as_missing must be false when fitting linear trees."); + } + if (objective == std::string("regresson_l1")) { + Log::Fatal("Cannot use regression_l1 objective when fitting linear trees."); + } + } // min_data_in_leaf must be at least 2 if path smoothing is active. 
This is because when the split is calculated // the count is calculated using the proportion of hessian in the leaf which is rounded up to nearest int, so it can // be 1 when there is actually no data in the leaf. In rare cases this can cause a bug because with path smoothing the diff --git a/src/io/config_auto.cpp b/src/io/config_auto.cpp index b4a55ce589ef..06c53e84268a 100644 --- a/src/io/config_auto.cpp +++ b/src/io/config_auto.cpp @@ -174,6 +174,7 @@ const std::unordered_set& Config::parameter_set() { "task", "objective", "boosting", + "linear_tree", "data", "valid", "num_iterations", @@ -205,6 +206,7 @@ const std::unordered_set& Config::parameter_set() { "max_delta_step", "lambda_l1", "lambda_l2", + "linear_lambda", "min_gain_to_split", "drop_rate", "max_drop", @@ -304,6 +306,8 @@ const std::unordered_set& Config::parameter_set() { void Config::GetMembersFromString(const std::unordered_map& params) { std::string tmp_str = ""; + GetBool(params, "linear_tree", &linear_tree); + GetString(params, "data", &data); if (GetString(params, "valid", &tmp_str)) { @@ -380,6 +384,9 @@ void Config::GetMembersFromString(const std::unordered_map>* bin_mappers, bin_construct_sample_cnt_ = io_config.bin_construct_sample_cnt; use_missing_ = io_config.use_missing; zero_as_missing_ = io_config.zero_as_missing; + has_raw_ = false; + if (io_config.linear_tree) { + has_raw_ = true; + } + numeric_feature_map_ = std::vector(num_features_, -1); + num_numeric_features_ = 0; + for (int i = 0; i < num_features_; ++i) { + if (FeatureBinMapper(i)->bin_type() == BinType::NumericalBin) { + numeric_feature_map_[i] = num_numeric_features_; + ++num_numeric_features_; + } + } } void Dataset::FinishLoad() { @@ -696,6 +710,7 @@ void Dataset::CopyFeatureMapperFrom(const Dataset* dataset) { feature_groups_.clear(); num_features_ = dataset->num_features_; num_groups_ = dataset->num_groups_; + has_raw_ = dataset->has_raw(); // copy feature bin mapper data for (int i = 0; i < num_groups_; ++i) { feature_groups_.emplace_back( @@ -722,7 +737,10 @@ void Dataset::CreateValid(const Dataset* dataset) { num_groups_ = num_features_; feature2group_.clear(); feature2subfeature_.clear(); - + has_raw_ = dataset->has_raw(); + numeric_feature_map_ = dataset->numeric_feature_map_; + num_numeric_features_ = dataset->num_numeric_features_; + // copy feature bin mapper data feature_need_push_zeros_.clear(); group_bin_boundaries_.clear(); uint64_t num_total_bin = 0; @@ -785,6 +803,17 @@ void Dataset::CopySubrow(const Dataset* fullset, metadata_.Init(fullset->metadata_, used_indices, num_used_indices); } is_finish_load_ = true; + numeric_feature_map_ = fullset->numeric_feature_map_; + num_numeric_features_ = fullset->num_numeric_features_; + if (has_raw_) { + ResizeRaw(num_used_indices); +#pragma omp parallel for schedule(static) + for (int i = 0; i < num_used_indices; ++i) { + for (int j = 0; j < num_numeric_features_; ++j) { + raw_data_[j][i] = fullset->raw_data_[j][used_indices[i]]; + } + } + } } bool Dataset::SetFloatField(const char* field_name, const float* field_data, @@ -922,8 +951,7 @@ void Dataset::SaveBinaryFile(const char* bin_filename) { 2 * VirtualFileWriter::AlignedSize(sizeof(int) * num_groups_) + VirtualFileWriter::AlignedSize(sizeof(int32_t) * num_total_features_) + VirtualFileWriter::AlignedSize(sizeof(int)) * 3 + - VirtualFileWriter::AlignedSize(sizeof(bool)) * 2; - + VirtualFileWriter::AlignedSize(sizeof(bool)) * 3; // size of feature names for (int i = 0; i < num_total_features_; ++i) { size_of_header += @@ -947,6 +975,7 @@ 
void Dataset::SaveBinaryFile(const char* bin_filename) { writer->AlignedWrite(&min_data_in_bin_, sizeof(min_data_in_bin_)); writer->AlignedWrite(&use_missing_, sizeof(use_missing_)); writer->AlignedWrite(&zero_as_missing_, sizeof(zero_as_missing_)); + writer->AlignedWrite(&has_raw_, sizeof(has_raw_)); writer->AlignedWrite(used_feature_map_.data(), sizeof(int) * num_total_features_); writer->AlignedWrite(&num_groups_, sizeof(num_groups_)); @@ -998,6 +1027,18 @@ void Dataset::SaveBinaryFile(const char* bin_filename) { // write feature feature_groups_[i]->SaveBinaryToFile(writer.get()); } + + // write raw data; use row-major order so we can read row-by-row + if (has_raw_) { + for (int i = 0; i < num_data_; ++i) { + for (int j = 0; j < num_features_; ++j) { + int feat_ind = numeric_feature_map_[j]; + if (feat_ind > -1) { + writer->Write(&raw_data_[feat_ind][i], sizeof(float)); + } + } + } + } } } @@ -1264,6 +1305,9 @@ void Dataset::AddFeaturesFrom(Dataset* other) { "Cannot add features from other Dataset with a different number of " "rows"); } + if (other->has_raw_ != has_raw_) { + Log::Fatal("Can only add features from other Dataset if both or neither have raw data."); + } int mv_gid = -1; int other_mv_gid = -1; for (int i = 0; i < num_groups_; ++i) { @@ -1393,6 +1437,20 @@ void Dataset::AddFeaturesFrom(Dataset* other) { PushClearIfEmpty(&max_bin_by_feature_, num_total_features_, other->max_bin_by_feature_, other->num_total_features_, -1); num_total_features_ += other->num_total_features_; + for (size_t i = 0; i < (other->numeric_feature_map_).size(); ++i) { + int feat_ind = numeric_feature_map_[i]; + if (feat_ind > -1) { + numeric_feature_map_.push_back(feat_ind + num_numeric_features_); + } else { + numeric_feature_map_.push_back(-1); + } + } + num_numeric_features_ += other->num_numeric_features_; + if (has_raw_) { + for (int i = 0; i < other->num_numeric_features_; ++i) { + raw_data_.push_back(other->raw_data_[i]); + } + } } } // namespace LightGBM diff --git a/src/io/dataset_loader.cpp b/src/io/dataset_loader.cpp index ba707329f7d2..24c9637bfa56 100644 --- a/src/io/dataset_loader.cpp +++ b/src/io/dataset_loader.cpp @@ -22,6 +22,10 @@ DatasetLoader::DatasetLoader(const Config& io_config, const PredictFunction& pre weight_idx_ = NO_SPECIFIC; group_idx_ = NO_SPECIFIC; SetHeader(filename); + store_raw_ = false; + if (io_config.linear_tree) { + store_raw_ = true; + } } DatasetLoader::~DatasetLoader() { @@ -183,6 +187,9 @@ Dataset* DatasetLoader::LoadFromFile(const char* filename, int rank, int num_mac } } auto dataset = std::unique_ptr(new Dataset()); + if (store_raw_) { + dataset->SetHasRaw(true); + } data_size_t num_global_data = 0; std::vector used_data_indices; auto bin_filename = CheckCanLoadFromBin(filename); @@ -205,6 +212,9 @@ Dataset* DatasetLoader::LoadFromFile(const char* filename, int rank, int num_mac static_cast(dataset->num_data_)); // construct feature bin mappers ConstructBinMappersFromTextData(rank, num_machines, sample_data, parser.get(), dataset.get()); + if (dataset->has_raw()) { + dataset->ResizeRaw(dataset->num_data_); + } // initialize label dataset->metadata_.Init(dataset->num_data_, weight_idx_, group_idx_); // extract features @@ -222,6 +232,9 @@ Dataset* DatasetLoader::LoadFromFile(const char* filename, int rank, int num_mac static_cast(dataset->num_data_)); // construct feature bin mappers ConstructBinMappersFromTextData(rank, num_machines, sample_data, parser.get(), dataset.get()); + if (dataset->has_raw()) { + dataset->ResizeRaw(dataset->num_data_); + } // 
initialize label dataset->metadata_.Init(dataset->num_data_, weight_idx_, group_idx_); Log::Debug("Making second pass..."); @@ -248,6 +261,9 @@ Dataset* DatasetLoader::LoadFromFileAlignWithOtherDataset(const char* filename, data_size_t num_global_data = 0; std::vector used_data_indices; auto dataset = std::unique_ptr(new Dataset()); + if (store_raw_) { + dataset->SetHasRaw(true); + } auto bin_filename = CheckCanLoadFromBin(filename); if (bin_filename.size() == 0) { auto parser = std::unique_ptr(Parser::CreateParser(filename, config_.header, 0, label_idx_)); @@ -264,6 +280,9 @@ Dataset* DatasetLoader::LoadFromFileAlignWithOtherDataset(const char* filename, // initialize label dataset->metadata_.Init(dataset->num_data_, weight_idx_, group_idx_); dataset->CreateValid(train_data); + if (dataset->has_raw()) { + dataset->ResizeRaw(dataset->num_data_); + } // extract features ExtractFeaturesFromMemory(&text_data, parser.get(), dataset.get()); text_data.clear(); @@ -275,6 +294,9 @@ Dataset* DatasetLoader::LoadFromFileAlignWithOtherDataset(const char* filename, // initialize label dataset->metadata_.Init(dataset->num_data_, weight_idx_, group_idx_); dataset->CreateValid(train_data); + if (dataset->has_raw()) { + dataset->ResizeRaw(dataset->num_data_); + } // extract features ExtractFeaturesFromFile(filename, parser.get(), used_data_indices, dataset.get()); } @@ -356,6 +378,8 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->use_missing_)); dataset->zero_as_missing_ = *(reinterpret_cast(mem_ptr)); mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->zero_as_missing_)); + dataset->has_raw_ = *(reinterpret_cast(mem_ptr)); + mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->has_raw_)); const int* tmp_feature_map = reinterpret_cast(mem_ptr); dataset->used_feature_map_.clear(); for (int i = 0; i < dataset->num_total_features_; ++i) { @@ -550,6 +574,43 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b *used_data_indices, i))); } dataset->feature_groups_.shrink_to_fit(); + + // raw data + dataset->numeric_feature_map_ = std::vector(dataset->num_features_, false); + dataset->num_numeric_features_ = 0; + for (int i = 0; i < dataset->num_features_; ++i) { + if (dataset->FeatureBinMapper(i)->bin_type() == BinType::CategoricalBin) { + dataset->numeric_feature_map_[i] = -1; + } else { + dataset->numeric_feature_map_[i] = dataset->num_numeric_features_; + ++dataset->num_numeric_features_; + } + } + if (dataset->has_raw()) { + dataset->ResizeRaw(dataset->num_data()); + size_t row_size = dataset->num_numeric_features_ * sizeof(float); + if (row_size > buffer_size) { + buffer_size = row_size; + buffer.resize(buffer_size); + } + for (int i = 0; i < dataset->num_data(); ++i) { + read_cnt = reader->Read(buffer.data(), row_size); + if (read_cnt != row_size) { + Log::Fatal("Binary file error: row %d of raw data is incorrect, read count: %d", i, read_cnt); + } + mem_ptr = buffer.data(); + const float* tmp_ptr_raw_row = reinterpret_cast(mem_ptr); + std::vector curr_row(dataset->num_numeric_features_, 0); + for (int j = 0; j < dataset->num_features(); ++j) { + int feat_ind = dataset->numeric_feature_map_[j]; + if (feat_ind >= 0) { + dataset->raw_data_[feat_ind][i] = tmp_ptr_raw_row[feat_ind]; + } + } + mem_ptr += row_size; + } + } + dataset->is_finish_load_ = true; return dataset.release(); } @@ -704,6 +765,9 @@ Dataset* DatasetLoader::ConstructFromSampleData(double** sample_values, } 
auto dataset = std::unique_ptr(new Dataset(num_data)); dataset->Construct(&bin_mappers, num_total_features, forced_bin_bounds, sample_indices, sample_values, num_per_col, num_col, total_sample_size, config_); + if (dataset->has_raw()) { + dataset->ResizeRaw(num_data); + } dataset->set_feature_names(feature_names_); return dataset.release(); } @@ -1061,6 +1125,9 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines, dataset->Construct(&bin_mappers, dataset->num_total_features_, forced_bin_bounds, Common::Vector2Ptr(&sample_indices).data(), Common::Vector2Ptr(&sample_values).data(), Common::VectorSize(sample_indices).data(), static_cast(sample_indices.size()), sample_data.size(), config_); + if (dataset->has_raw()) { + dataset->ResizeRaw(sample_data.size()); + } } /*! \brief Extract local features from memory */ @@ -1068,10 +1135,11 @@ void DatasetLoader::ExtractFeaturesFromMemory(std::vector* text_dat std::vector> oneline_features; double tmp_label = 0.0f; auto& ref_text_data = *text_data; + std::vector feature_row(dataset->num_features_); if (predict_fun_ == nullptr) { OMP_INIT_EX(); // if doesn't need to prediction with initial model - #pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label) + #pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label, feature_row) for (data_size_t i = 0; i < dataset->num_data_; ++i) { OMP_LOOP_EX_BEGIN(); const int tid = omp_get_thread_num(); @@ -1095,6 +1163,9 @@ void DatasetLoader::ExtractFeaturesFromMemory(std::vector* text_dat int group = dataset->feature2group_[feature_idx]; int sub_feature = dataset->feature2subfeature_[feature_idx]; dataset->feature_groups_[group]->PushData(tid, sub_feature, i, inner_data.second); + if (dataset->has_raw()) { + feature_row[feature_idx] = inner_data.second; + } } else { if (inner_data.first == weight_idx_) { dataset->metadata_.SetWeightAt(i, static_cast(inner_data.second)); @@ -1103,6 +1174,14 @@ void DatasetLoader::ExtractFeaturesFromMemory(std::vector* text_dat } } } + if (dataset->has_raw()) { + for (size_t j = 0; j < feature_row.size(); ++j) { + int feat_ind = dataset->numeric_feature_map_[j]; + if (feat_ind >= 0) { + dataset->raw_data_[feat_ind][i] = feature_row[j]; + } + } + } dataset->FinishOneRow(tid, i, is_feature_added); OMP_LOOP_EX_END(); } @@ -1111,7 +1190,7 @@ void DatasetLoader::ExtractFeaturesFromMemory(std::vector* text_dat OMP_INIT_EX(); // if need to prediction with initial model std::vector init_score(dataset->num_data_ * num_class_); - #pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label) + #pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label, feature_row) for (data_size_t i = 0; i < dataset->num_data_; ++i) { OMP_LOOP_EX_BEGIN(); const int tid = omp_get_thread_num(); @@ -1141,6 +1220,9 @@ void DatasetLoader::ExtractFeaturesFromMemory(std::vector* text_dat int group = dataset->feature2group_[feature_idx]; int sub_feature = dataset->feature2subfeature_[feature_idx]; dataset->feature_groups_[group]->PushData(tid, sub_feature, i, inner_data.second); + if (dataset->has_raw()) { + feature_row[feature_idx] = inner_data.second; + } } else { if (inner_data.first == weight_idx_) { dataset->metadata_.SetWeightAt(i, static_cast(inner_data.second)); @@ -1150,6 +1232,14 @@ void DatasetLoader::ExtractFeaturesFromMemory(std::vector* text_dat } } dataset->FinishOneRow(tid, i, is_feature_added); + if (dataset->has_raw()) { + for (size_t j = 0; j 
< feature_row.size(); ++j) { + int feat_ind = dataset->numeric_feature_map_[j]; + if (feat_ind >= 0) { + dataset->raw_data_[feat_ind][i] = feature_row[j]; + } + } + } OMP_LOOP_EX_END(); } OMP_THROW_EX(); @@ -1173,8 +1263,9 @@ void DatasetLoader::ExtractFeaturesFromFile(const char* filename, const Parser* (data_size_t start_idx, const std::vector& lines) { std::vector> oneline_features; double tmp_label = 0.0f; + std::vector feature_row(dataset->num_features_); OMP_INIT_EX(); - #pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label) + #pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label, feature_row) for (data_size_t i = 0; i < static_cast(lines.size()); ++i) { OMP_LOOP_EX_BEGIN(); const int tid = omp_get_thread_num(); @@ -1202,6 +1293,9 @@ void DatasetLoader::ExtractFeaturesFromFile(const char* filename, const Parser* int group = dataset->feature2group_[feature_idx]; int sub_feature = dataset->feature2subfeature_[feature_idx]; dataset->feature_groups_[group]->PushData(tid, sub_feature, start_idx + i, inner_data.second); + if (dataset->has_raw()) { + feature_row[feature_idx] = inner_data.second; + } } else { if (inner_data.first == weight_idx_) { dataset->metadata_.SetWeightAt(start_idx + i, static_cast(inner_data.second)); @@ -1210,6 +1304,14 @@ void DatasetLoader::ExtractFeaturesFromFile(const char* filename, const Parser* } } } + if (dataset->has_raw()) { + for (size_t j = 0; j < feature_row.size(); ++j) { + int feat_ind = dataset->numeric_feature_map_[j]; + if (feat_ind >= 0) { + dataset->raw_data_[feat_ind][i] = feature_row[j]; + } + } + } dataset->FinishOneRow(tid, i, is_feature_added); OMP_LOOP_EX_END(); } diff --git a/src/io/tree.cpp b/src/io/tree.cpp index a4e42831b0e7..91eb0a0ab116 100644 --- a/src/io/tree.cpp +++ b/src/io/tree.cpp @@ -14,7 +14,7 @@ namespace LightGBM { -Tree::Tree(int max_leaves, bool track_branch_features) +Tree::Tree(int max_leaves, bool track_branch_features, bool is_linear) :max_leaves_(max_leaves), track_branch_features_(track_branch_features) { left_child_.resize(max_leaves_ - 1); right_child_.resize(max_leaves_ - 1); @@ -46,6 +46,13 @@ Tree::Tree(int max_leaves, bool track_branch_features) cat_boundaries_.push_back(0); cat_boundaries_inner_.push_back(0); max_depth_ = -1; + is_linear_ = is_linear; + if (is_linear_) { + leaf_coeff_.resize(max_leaves_); + leaf_const_ = std::vector(max_leaves_, 0); + leaf_features_.resize(max_leaves_); + leaf_features_inner_.resize(max_leaves_); + } } int Tree::Split(int leaf, int feature, int real_feature, uint32_t threshold_bin, @@ -103,8 +110,42 @@ int Tree::SplitCategorical(int leaf, int feature, int real_feature, const uint32 score[(data_idx)] += static_cast(leaf_value_[~node]); \ }\ + +#define PredictionFunLinear(niter, fidx_in_iter, start_pos, decision_fun, \ + iter_idx, data_idx) \ + std::vector> iter((niter)); \ + for (int i = 0; i < (niter); ++i) { \ + iter[i].reset(data->FeatureIterator((fidx_in_iter))); \ + iter[i]->Reset((start_pos)); \ + } \ + for (data_size_t i = start; i < end; ++i) { \ + int node = 0; \ + while (node >= 0) { \ + node = decision_fun(iter[(iter_idx)]->Get((data_idx)), node, \ + default_bins[node], max_bins[node]); \ + } \ + double add_score = leaf_const_[~node]; \ + bool nan_found = false; \ + const double* coeff_ptr = leaf_coeff_[~node].data(); \ + const float** data_ptr = feat_ptr[~node].data(); \ + for (size_t j = 0; j < leaf_features_inner_[~node].size(); ++j) { \ + float feat_val = data_ptr[j][(data_idx)]; \ + if 
(std::isnan(feat_val)) { \ + nan_found = true; \ + break; \ + } \ + add_score += coeff_ptr[j] * feat_val; \ + } \ + if (nan_found) { \ + score[(data_idx)] += leaf_value_[~node]; \ + } else { \ + score[(data_idx)] += add_score; \ + } \ +}\ + + void Tree::AddPredictionToScore(const Dataset* data, data_size_t num_data, double* score) const { - if (num_leaves_ <= 1) { + if (!is_linear_ && num_leaves_ <= 1) { if (leaf_value_[0] != 0.0f) { #pragma omp parallel for schedule(static, 512) if (num_data >= 1024) for (data_size_t i = 0; i < num_data; ++i) { @@ -121,37 +162,71 @@ void Tree::AddPredictionToScore(const Dataset* data, data_size_t num_data, doubl default_bins[i] = bin_mapper->GetDefaultBin(); max_bins[i] = bin_mapper->num_bin() - 1; } - if (num_cat_ > 0) { - if (data->num_features() > num_leaves_ - 1) { - Threading::For(0, num_data, 512, [this, &data, score, &default_bins, &max_bins] - (int, data_size_t start, data_size_t end) { - PredictionFun(num_leaves_ - 1, split_feature_inner_[i], start, DecisionInner, node, i); - }); + if (is_linear_) { + std::vector> feat_ptr(num_leaves_); + for (int leaf_num = 0; leaf_num < num_leaves_; ++leaf_num) { + for (int feat : leaf_features_inner_[leaf_num]) { + feat_ptr[leaf_num].push_back(data->raw_index(feat)); + } + } + if (num_cat_ > 0) { + if (data->num_features() > num_leaves_ - 1) { + Threading::For(0, num_data, 512, [this, &data, score, &default_bins, &max_bins, &feat_ptr] + (int, data_size_t start, data_size_t end) { + PredictionFunLinear(num_leaves_ - 1, split_feature_inner_[i], start, DecisionInner, node, i); + }); + } else { + Threading::For(0, num_data, 512, [this, &data, score, &default_bins, &max_bins, &feat_ptr] + (int, data_size_t start, data_size_t end) { + PredictionFunLinear(data->num_features(), i, start, DecisionInner, split_feature_inner_[node], i); + }); + } } else { - Threading::For(0, num_data, 512, [this, &data, score, &default_bins, &max_bins] - (int, data_size_t start, data_size_t end) { - PredictionFun(data->num_features(), i, start, DecisionInner, split_feature_inner_[node], i); - }); + if (data->num_features() > num_leaves_ - 1) { + Threading::For(0, num_data, 512, [this, &data, score, &default_bins, &max_bins, &feat_ptr] + (int, data_size_t start, data_size_t end) { + PredictionFunLinear(num_leaves_ - 1, split_feature_inner_[i], start, NumericalDecisionInner, node, i); + }); + } else { + Threading::For(0, num_data, 512, [this, &data, score, &default_bins, &max_bins, &feat_ptr] + (int, data_size_t start, data_size_t end) { + PredictionFunLinear(data->num_features(), i, start, NumericalDecisionInner, split_feature_inner_[node], i); + }); + } } } else { - if (data->num_features() > num_leaves_ - 1) { - Threading::For(0, num_data, 512, [this, &data, score, &default_bins, &max_bins] - (int, data_size_t start, data_size_t end) { - PredictionFun(num_leaves_ - 1, split_feature_inner_[i], start, NumericalDecisionInner, node, i); - }); + if (num_cat_ > 0) { + if (data->num_features() > num_leaves_ - 1) { + Threading::For(0, num_data, 512, [this, &data, score, &default_bins, &max_bins] + (int, data_size_t start, data_size_t end) { + PredictionFun(num_leaves_ - 1, split_feature_inner_[i], start, DecisionInner, node, i); + }); + } else { + Threading::For(0, num_data, 512, [this, &data, score, &default_bins, &max_bins] + (int, data_size_t start, data_size_t end) { + PredictionFun(data->num_features(), i, start, DecisionInner, split_feature_inner_[node], i); + }); + } } else { - Threading::For(0, num_data, 512, [this, &data, score, 
&default_bins, &max_bins] - (int, data_size_t start, data_size_t end) { - PredictionFun(data->num_features(), i, start, NumericalDecisionInner, split_feature_inner_[node], i); - }); + if (data->num_features() > num_leaves_ - 1) { + Threading::For(0, num_data, 512, [this, &data, score, &default_bins, &max_bins] + (int, data_size_t start, data_size_t end) { + PredictionFun(num_leaves_ - 1, split_feature_inner_[i], start, NumericalDecisionInner, node, i); + }); + } else { + Threading::For(0, num_data, 512, [this, &data, score, &default_bins, &max_bins] + (int, data_size_t start, data_size_t end) { + PredictionFun(data->num_features(), i, start, NumericalDecisionInner, split_feature_inner_[node], i); + }); + } } } } void Tree::AddPredictionToScore(const Dataset* data, - const data_size_t* used_data_indices, - data_size_t num_data, double* score) const { - if (num_leaves_ <= 1) { + const data_size_t* used_data_indices, + data_size_t num_data, double* score) const { + if (!is_linear_ && num_leaves_ <= 1) { if (leaf_value_[0] != 0.0f) { #pragma omp parallel for schedule(static, 512) if (num_data >= 1024) for (data_size_t i = 0; i < num_data; ++i) { @@ -168,34 +243,72 @@ void Tree::AddPredictionToScore(const Dataset* data, default_bins[i] = bin_mapper->GetDefaultBin(); max_bins[i] = bin_mapper->num_bin() - 1; } - if (num_cat_ > 0) { - if (data->num_features() > num_leaves_ - 1) { - Threading::For(0, num_data, 512, [this, &data, score, used_data_indices, &default_bins, &max_bins] - (int, data_size_t start, data_size_t end) { - PredictionFun(num_leaves_ - 1, split_feature_inner_[i], used_data_indices[start], DecisionInner, node, used_data_indices[i]); - }); + if (is_linear_) { + std::vector> feat_ptr(num_leaves_); + for (int leaf_num = 0; leaf_num < num_leaves_; ++leaf_num) { + for (int feat : leaf_features_inner_[leaf_num]) { + feat_ptr[leaf_num].push_back(data->raw_index(feat)); + } + } + if (num_cat_ > 0) { + if (data->num_features() > num_leaves_ - 1) { + Threading::For(0, num_data, 512, [this, &data, score, used_data_indices, &default_bins, &max_bins, &feat_ptr] + (int, data_size_t start, data_size_t end) { + PredictionFunLinear(num_leaves_ - 1, split_feature_inner_[i], used_data_indices[start], DecisionInner, + node, used_data_indices[i]); + }); + } else { + Threading::For(0, num_data, 512, [this, &data, score, used_data_indices, &default_bins, &max_bins, &feat_ptr] + (int, data_size_t start, data_size_t end) { + PredictionFunLinear(data->num_features(), i, used_data_indices[start], DecisionInner, split_feature_inner_[node], used_data_indices[i]); + }); + } } else { - Threading::For(0, num_data, 512, [this, &data, score, used_data_indices, &default_bins, &max_bins] - (int, data_size_t start, data_size_t end) { - PredictionFun(data->num_features(), i, used_data_indices[start], DecisionInner, split_feature_inner_[node], used_data_indices[i]); - }); + if (data->num_features() > num_leaves_ - 1) { + Threading::For(0, num_data, 512, [this, &data, score, used_data_indices, &default_bins, &max_bins, &feat_ptr] + (int, data_size_t start, data_size_t end) { + PredictionFunLinear(num_leaves_ - 1, split_feature_inner_[i], used_data_indices[start], NumericalDecisionInner, + node, used_data_indices[i]); + }); + } else { + Threading::For(0, num_data, 512, [this, &data, score, used_data_indices, &default_bins, &max_bins, &feat_ptr] + (int, data_size_t start, data_size_t end) { + PredictionFunLinear(data->num_features(), i, used_data_indices[start], NumericalDecisionInner, + split_feature_inner_[node], 
used_data_indices[i]); + }); + } } } else { - if (data->num_features() > num_leaves_ - 1) { - Threading::For(0, num_data, 512, [this, &data, score, used_data_indices, &default_bins, &max_bins] - (int, data_size_t start, data_size_t end) { - PredictionFun(num_leaves_ - 1, split_feature_inner_[i], used_data_indices[start], NumericalDecisionInner, node, used_data_indices[i]); - }); + if (num_cat_ > 0) { + if (data->num_features() > num_leaves_ - 1) { + Threading::For(0, num_data, 512, [this, &data, score, used_data_indices, &default_bins, &max_bins] + (int, data_size_t start, data_size_t end) { + PredictionFun(num_leaves_ - 1, split_feature_inner_[i], used_data_indices[start], DecisionInner, node, used_data_indices[i]); + }); + } else { + Threading::For(0, num_data, 512, [this, &data, score, used_data_indices, &default_bins, &max_bins] + (int, data_size_t start, data_size_t end) { + PredictionFun(data->num_features(), i, used_data_indices[start], DecisionInner, split_feature_inner_[node], used_data_indices[i]); + }); + } } else { - Threading::For(0, num_data, 512, [this, &data, score, used_data_indices, &default_bins, &max_bins] - (int, data_size_t start, data_size_t end) { - PredictionFun(data->num_features(), i, used_data_indices[start], NumericalDecisionInner, split_feature_inner_[node], used_data_indices[i]); - }); + if (data->num_features() > num_leaves_ - 1) { + Threading::For(0, num_data, 512, [this, &data, score, used_data_indices, &default_bins, &max_bins] + (int, data_size_t start, data_size_t end) { + PredictionFun(num_leaves_ - 1, split_feature_inner_[i], used_data_indices[start], NumericalDecisionInner, node, used_data_indices[i]); + }); + } else { + Threading::For(0, num_data, 512, [this, &data, score, used_data_indices, &default_bins, &max_bins] + (int, data_size_t start, data_size_t end) { + PredictionFun(data->num_features(), i, used_data_indices[start], NumericalDecisionInner, split_feature_inner_[node], used_data_indices[i]); + }); + } } } } #undef PredictionFun +#undef PredictionFunLinear double Tree::GetUpperBoundValue() const { double upper_bound = leaf_value_[0]; @@ -259,8 +372,37 @@ std::string Tree::ToString() const { str_buf << "cat_threshold=" << ArrayToString(cat_threshold_, cat_threshold_.size()) << '\n'; } + str_buf << "is_linear=" << is_linear_ << '\n'; + + if (is_linear_) { + str_buf << "leaf_const=" + << ArrayToString(leaf_const_, num_leaves_) << '\n'; + std::vector num_feat(num_leaves_); + for (int i = 0; i < num_leaves_; ++i) { + num_feat[i] = leaf_coeff_[i].size(); + } + str_buf << "num_features=" + << ArrayToString(num_feat, num_leaves_) << '\n'; + str_buf << "leaf_features="; + for (int i = 0; i < num_leaves_; ++i) { + if (num_feat[i] > 0) { + str_buf << ArrayToString(leaf_features_[i], leaf_features_[i].size()) << ' '; + } + str_buf << ' '; + } + str_buf << '\n'; + str_buf << "leaf_coeff="; + for (int i = 0; i < num_leaves_; ++i) { + if (num_feat[i] > 0) { + str_buf << ArrayToString(leaf_coeff_[i], leaf_coeff_[i].size()) << ' '; + } + str_buf << ' '; + } + str_buf << '\n'; + } str_buf << "shrinkage=" << shrinkage_ << '\n'; str_buf << '\n'; + return str_buf.str(); } @@ -508,7 +650,7 @@ std::string Tree::NodeToIfElseByMap(int index, bool predict_leaf_index) const { Tree::Tree(const char* str, size_t* used_len) { auto p = str; std::unordered_map key_vals; - const int max_num_line = 17; + const int max_num_line = 22; int read_line = 0; while (read_line < max_num_line) { if (*p == '\r' || *p == '\n') break; @@ -549,7 +691,13 @@ Tree::Tree(const char* str, 
size_t* used_len) { shrinkage_ = 1.0f; } - if (num_leaves_ <= 1) { return; } + if (key_vals.count("is_linear")) { + int is_linear_int; + Common::Atoi(key_vals["is_linear"].c_str(), &is_linear_int); + is_linear_ = static_cast(is_linear_int); + } + + if ((num_leaves_ <= 1) && !is_linear_) { return; } if (key_vals.count("left_child")) { left_child_ = CommonC::StringToArrayFast(key_vals["left_child"], num_leaves_ - 1); @@ -617,6 +765,45 @@ Tree::Tree(const char* str, size_t* used_len) { decision_type_ = std::vector(num_leaves_ - 1, 0); } + if (is_linear_) { + if (key_vals.count("leaf_const")) { + leaf_const_ = Common::StringToArrayFast(key_vals["leaf_const"], num_leaves_); + } else { + leaf_const_.resize(num_leaves_); + } + std::vector num_feat; + if (key_vals.count("num_features")) { + num_feat = Common::StringToArrayFast(key_vals["num_features"], num_leaves_); + } + leaf_coeff_.resize(num_leaves_); + leaf_features_.resize(num_leaves_); + leaf_features_inner_.resize(num_leaves_); + if (num_feat.size() > 0) { + int total_num_feat = 0; + for (size_t i = 0; i < num_feat.size(); ++i) { total_num_feat += num_feat[i]; } + std::vector all_leaf_features; + if (key_vals.count("leaf_features")) { + all_leaf_features = Common::StringToArrayFast(key_vals["leaf_features"], total_num_feat); + } + std::vector all_leaf_coeff; + if (key_vals.count("leaf_coeff")) { + all_leaf_coeff = Common::StringToArrayFast(key_vals["leaf_coeff"], total_num_feat); + } + int sum_num_feat = 0; + for (int i = 0; i < num_leaves_; ++i) { + if (num_feat[i] > 0) { + if (key_vals.count("leaf_features")) { + leaf_features_[i].assign(all_leaf_features.begin() + sum_num_feat, all_leaf_features.begin() + sum_num_feat + num_feat[i]); + } + if (key_vals.count("leaf_coeff")) { + leaf_coeff_[i].assign(all_leaf_coeff.begin() + sum_num_feat, all_leaf_coeff.begin() + sum_num_feat + num_feat[i]); + } + } + sum_num_feat += num_feat[i]; + } + } + } + if (num_cat_ > 0) { if (key_vals.count("cat_boundaries")) { cat_boundaries_ = CommonC::StringToArrayFast(key_vals["cat_boundaries"], num_cat_ + 1); diff --git a/src/treelearner/gpu_tree_learner.cpp b/src/treelearner/gpu_tree_learner.cpp index df90aafb945c..efe36072544b 100644 --- a/src/treelearner/gpu_tree_learner.cpp +++ b/src/treelearner/gpu_tree_learner.cpp @@ -734,8 +734,8 @@ void GPUTreeLearner::InitGPU(int platform_id, int device_id) { SetupKernelArguments(); } -Tree* GPUTreeLearner::Train(const score_t* gradients, const score_t *hessians) { - return SerialTreeLearner::Train(gradients, hessians); +Tree* GPUTreeLearner::Train(const score_t* gradients, const score_t *hessians, bool is_first_tree) { + return SerialTreeLearner::Train(gradients, hessians, is_first_tree); } void GPUTreeLearner::ResetTrainingDataInner(const Dataset* train_data, bool is_constant_hessian, bool reset_multi_val_bin) { diff --git a/src/treelearner/gpu_tree_learner.h b/src/treelearner/gpu_tree_learner.h index a909c57cbadc..91980ce934ea 100644 --- a/src/treelearner/gpu_tree_learner.h +++ b/src/treelearner/gpu_tree_learner.h @@ -48,7 +48,7 @@ class GPUTreeLearner: public SerialTreeLearner { void Init(const Dataset* train_data, bool is_constant_hessian) override; void ResetTrainingDataInner(const Dataset* train_data, bool is_constant_hessian, bool reset_multi_val_bin) override; void ResetIsConstantHessian(bool is_constant_hessian) override; - Tree* Train(const score_t* gradients, const score_t *hessians) override; + Tree* Train(const score_t* gradients, const score_t *hessians, bool is_first_tree) override; void 
SetBaggingData(const Dataset* subset, const data_size_t* used_indices, data_size_t num_data) override { SerialTreeLearner::SetBaggingData(subset, used_indices, num_data); diff --git a/src/treelearner/linear_tree_learner.cpp b/src/treelearner/linear_tree_learner.cpp new file mode 100644 index 000000000000..58482b8eef9b --- /dev/null +++ b/src/treelearner/linear_tree_learner.cpp @@ -0,0 +1,390 @@ +/*! + * Copyright (c) 2016 Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See LICENSE file in the project root for license information. + */ +#include "linear_tree_learner.h" + +#include + +#ifndef LGB_R_BUILD +// preprocessor definition ensures we use only MPL2-licensed code +#define EIGEN_MPL2_ONLY +#include +#endif // !LGB_R_BUILD + +namespace LightGBM { + +void LinearTreeLearner::Init(const Dataset* train_data, bool is_constant_hessian) { + SerialTreeLearner::Init(train_data, is_constant_hessian); + LinearTreeLearner::InitLinear(train_data, config_->num_leaves); +} + +void LinearTreeLearner::InitLinear(const Dataset* train_data, const int max_leaves) { + leaf_map_ = std::vector(train_data->num_data(), -1); + contains_nan_ = std::vector(train_data->num_features(), 0); + // identify features containing nans +#pragma omp parallel for schedule(static) + for (int feat = 0; feat < train_data->num_features(); ++feat) { + auto bin_mapper = train_data_->FeatureBinMapper(feat); + if (bin_mapper->bin_type() == BinType::NumericalBin) { + const float* feat_ptr = train_data_->raw_index(feat); + for (int i = 0; i < train_data->num_data(); ++i) { + if (std::isnan(feat_ptr[i])) { + contains_nan_[feat] = 1; + break; + } + } + } + } + for (int feat = 0; feat < train_data->num_features(); ++feat) { + if (contains_nan_[feat]) { + any_nan_ = true; + break; + } + } + // preallocate the matrix used to calculate linear model coefficients + int max_num_feat = std::min(max_leaves, train_data_->num_numeric_features()); + XTHX_.clear(); + XTg_.clear(); + for (int i = 0; i < max_leaves; ++i) { + // store only upper triangular half of matrix as an array, in row-major order + // this requires (max_num_feat + 1) * (max_num_feat + 2) / 2 entries (including the constant terms of the regression) + // we add another 8 to ensure cache lines are not shared among processors + XTHX_.push_back(std::vector((max_num_feat + 1) * (max_num_feat + 2) / 2 + 8, 0)); + XTg_.push_back(std::vector(max_num_feat + 9, 0.0)); + } + XTHX_by_thread_.clear(); + XTg_by_thread_.clear(); + int max_threads = omp_get_max_threads(); + for (int i = 0; i < max_threads; ++i) { + XTHX_by_thread_.push_back(XTHX_); + XTg_by_thread_.push_back(XTg_); + } +} + +Tree* LinearTreeLearner::Train(const score_t* gradients, const score_t *hessians, bool is_first_tree) { + Common::FunctionTimer fun_timer("SerialTreeLearner::Train", global_timer); + gradients_ = gradients; + hessians_ = hessians; + int num_threads = OMP_NUM_THREADS(); + if (share_state_->num_threads != num_threads && share_state_->num_threads > 0) { + Log::Warning( + "Detected that num_threads changed during training (from %d to %d), " + "it may cause unexpected errors.", + share_state_->num_threads, num_threads); + } + share_state_->num_threads = num_threads; + + // some initial works before training + BeforeTrain(); + + auto tree = std::unique_ptr(new Tree(config_->num_leaves, true, true)); + auto tree_ptr = tree.get(); + constraints_->ShareTreePointer(tree_ptr); + + // root leaf + int left_leaf = 0; + int cur_depth = 1; + // only root leaf can be splitted on first time + int 
right_leaf = -1; + + int init_splits = ForceSplits(tree_ptr, &left_leaf, &right_leaf, &cur_depth); + + for (int split = init_splits; split < config_->num_leaves - 1; ++split) { + // some initial works before finding best split + if (BeforeFindBestSplit(tree_ptr, left_leaf, right_leaf)) { + // find best threshold for every feature + FindBestSplits(tree_ptr); + } + // Get a leaf with max split gain + int best_leaf = static_cast(ArrayArgs::ArgMax(best_split_per_leaf_)); + // Get split information for best leaf + const SplitInfo& best_leaf_SplitInfo = best_split_per_leaf_[best_leaf]; + // cannot split, quit + if (best_leaf_SplitInfo.gain <= 0.0) { + Log::Warning("No further splits with positive gain, best gain: %f", best_leaf_SplitInfo.gain); + break; + } + // split tree with best leaf + Split(tree_ptr, best_leaf, &left_leaf, &right_leaf); + cur_depth = std::max(cur_depth, tree->leaf_depth(left_leaf)); + } + + bool has_nan = false; + if (any_nan_) { + for (int i = 0; i < tree->num_leaves() - 1 ; ++i) { + if (contains_nan_[tree_ptr->split_feature_inner(i)]) { + has_nan = true; + break; + } + } + } + + GetLeafMap(tree_ptr); + + if (has_nan) { + CalculateLinear(tree_ptr, false, gradients_, hessians_, is_first_tree); + } else { + CalculateLinear(tree_ptr, false, gradients_, hessians_, is_first_tree); + } + + Log::Debug("Trained a tree with leaves = %d and max_depth = %d", tree->num_leaves(), cur_depth); + return tree.release(); +} + +Tree* LinearTreeLearner::FitByExistingTree(const Tree* old_tree, const score_t* gradients, const score_t *hessians) const { + auto tree = SerialTreeLearner::FitByExistingTree(old_tree, gradients, hessians); + bool has_nan = false; + if (any_nan_) { + for (int i = 0; i < tree->num_leaves() - 1 ; ++i) { + if (contains_nan_[train_data_->InnerFeatureIndex(tree->split_feature(i))]) { + has_nan = true; + break; + } + } + } + GetLeafMap(tree); + if (has_nan) { + CalculateLinear(tree, true, gradients, hessians, false); + } else { + CalculateLinear(tree, true, gradients, hessians, false); + } + return tree; +} + +Tree* LinearTreeLearner::FitByExistingTree(const Tree* old_tree, const std::vector& leaf_pred, + const score_t* gradients, const score_t *hessians) const { + data_partition_->ResetByLeafPred(leaf_pred, old_tree->num_leaves()); + return LinearTreeLearner::FitByExistingTree(old_tree, gradients, hessians); +} + +void LinearTreeLearner::GetLeafMap(Tree* tree) const { + std::fill(leaf_map_.begin(), leaf_map_.end(), -1); + // map data to leaf number + const data_size_t* ind = data_partition_->indices(); +#pragma omp parallel for schedule(dynamic) + for (int i = 0; i < tree->num_leaves(); ++i) { + data_size_t idx = data_partition_->leaf_begin(i); + for (int j = 0; j < data_partition_->leaf_count(i); ++j) { + leaf_map_[ind[idx + j]] = i; + } + } +} + +#ifdef LGB_R_BUILD +template +void LinearTreeLearner::CalculateLinear(Tree* tree, bool is_refit, const score_t* gradients, const score_t* hessians, bool is_first_tree) const { + Log::Fatal("Linear tree learner does not work with R package."); +} +#else +template +void LinearTreeLearner::CalculateLinear(Tree* tree, bool is_refit, const score_t* gradients, const score_t* hessians, bool is_first_tree) const { + tree->SetIsLinear(true); + int num_leaves = tree->num_leaves(); + int num_threads = OMP_NUM_THREADS(); + if (is_first_tree) { + for (int leaf_num = 0; leaf_num < num_leaves; ++leaf_num) { + tree->SetLeafConst(leaf_num, tree->LeafOutput(leaf_num)); + } + return; + } + + // calculate coefficients using the method described in 
Eq 3 of https://arxiv.org/pdf/1802.05640.pdf + // the coefficients vector is given by + // - (X_T * H * X + lambda) ^ (-1) * (X_T * g) + // where: + // X is the matrix where the first column is the feature values and the second is all ones, + // H is the diagonal matrix of the hessian, + // lambda is the diagonal matrix with diagonal entries equal to the regularisation term linear_lambda + // g is the vector of gradients + // the subscript _T denotes the transpose + + // create array of pointers to raw data, and coefficient matrices, for each leaf + std::vector> leaf_features; + std::vector leaf_num_features; + std::vector> raw_data_ptr; + int max_num_features = 0; + for (int i = 0; i < num_leaves; ++i) { + std::vector raw_features; + if (is_refit) { + raw_features = tree->LeafFeatures(i); + } else { + raw_features = tree->branch_features(i); + } + std::sort(raw_features.begin(), raw_features.end()); + auto new_end = std::unique(raw_features.begin(), raw_features.end()); + raw_features.erase(new_end, raw_features.end()); + std::vector numerical_features; + std::vector data_ptr; + for (size_t j = 0; j < raw_features.size(); ++j) { + int feat = train_data_->InnerFeatureIndex(raw_features[j]); + auto bin_mapper = train_data_->FeatureBinMapper(feat); + if (bin_mapper->bin_type() == BinType::NumericalBin) { + numerical_features.push_back(feat); + data_ptr.push_back(train_data_->raw_index(feat)); + } + } + leaf_features.push_back(numerical_features); + raw_data_ptr.push_back(data_ptr); + leaf_num_features.push_back(numerical_features.size()); + if (static_cast(numerical_features.size()) > max_num_features) { + max_num_features = numerical_features.size(); + } + } + // clear the coefficient matrices +#pragma omp parallel for schedule(static) + for (int i = 0; i < num_threads; ++i) { + for (int leaf_num = 0; leaf_num < num_leaves; ++leaf_num) { + int num_feat = leaf_features[leaf_num].size(); + std::fill(XTHX_by_thread_[i][leaf_num].begin(), XTHX_by_thread_[i][leaf_num].begin() + (num_feat + 1) * (num_feat + 2) / 2, 0); + std::fill(XTg_by_thread_[i][leaf_num].begin(), XTg_by_thread_[i][leaf_num].begin() + num_feat + 1, 0); + } + } +#pragma omp parallel for schedule(static) + for (int leaf_num = 0; leaf_num < num_leaves; ++leaf_num) { + int num_feat = leaf_features[leaf_num].size(); + std::fill(XTHX_[leaf_num].begin(), XTHX_[leaf_num].begin() + (num_feat + 1) * (num_feat + 2) / 2, 0); + std::fill(XTg_[leaf_num].begin(), XTg_[leaf_num].begin() + num_feat + 1, 0); + } + std::vector> num_nonzero; + for (int i = 0; i < num_threads; ++i) { + if (HAS_NAN) { + num_nonzero.push_back(std::vector(num_leaves, 0)); + } + } + OMP_INIT_EX(); +#pragma omp parallel if (num_data_ > 1024) + { + std::vector curr_row(max_num_features + 1); + int tid = omp_get_thread_num(); +#pragma omp for schedule(static) + for (int i = 0; i < num_data_; ++i) { + OMP_LOOP_EX_BEGIN(); + int leaf_num = leaf_map_[i]; + if (leaf_num < 0) { + continue; + } + bool nan_found = false; + int num_feat = leaf_num_features[leaf_num]; + for (int feat = 0; feat < num_feat; ++feat) { + if (HAS_NAN) { + float val = raw_data_ptr[leaf_num][feat][i]; + if (std::isnan(val)) { + nan_found = true; + break; + } + num_nonzero[tid][leaf_num] += 1; + curr_row[feat] = val; + } else { + curr_row[feat] = raw_data_ptr[leaf_num][feat][i]; + } + } + if (HAS_NAN) { + if (nan_found) { + continue; + } + } + curr_row[num_feat] = 1.0; + double h = hessians[i]; + double g = gradients[i]; + int j = 0; + for (int feat1 = 0; feat1 < num_feat + 1; ++feat1) { + double f1_val 
= curr_row[feat1]; + XTg_by_thread_[tid][leaf_num][feat1] += f1_val * g; + f1_val *= h; + for (int feat2 = feat1; feat2 < num_feat + 1; ++feat2) { + XTHX_by_thread_[tid][leaf_num][j] += f1_val * curr_row[feat2]; + ++j; + } + } + OMP_LOOP_EX_END(); + } + } + OMP_THROW_EX(); + auto total_nonzero = std::vector(tree->num_leaves()); + // aggregate results from different threads + for (int tid = 0; tid < num_threads; ++tid) { +#pragma omp parallel for schedule(static) + for (int leaf_num = 0; leaf_num < num_leaves; ++leaf_num) { + int num_feat = leaf_features[leaf_num].size(); + for (int j = 0; j < (num_feat + 1) * (num_feat + 2) / 2; ++j) { + XTHX_[leaf_num][j] += XTHX_by_thread_[tid][leaf_num][j]; + } + for (int feat1 = 0; feat1 < num_feat + 1; ++feat1) { + XTg_[leaf_num][feat1] += XTg_by_thread_[tid][leaf_num][feat1]; + } + if (HAS_NAN) { + total_nonzero[leaf_num] += num_nonzero[tid][leaf_num]; + } + } + } + if (!HAS_NAN) { + for (int leaf_num = 0; leaf_num < num_leaves; ++leaf_num) { + total_nonzero[leaf_num] = data_partition_->leaf_count(leaf_num); + } + } + double shrinkage = tree->shrinkage(); + double decay_rate = config_->refit_decay_rate; + // copy into eigen matrices and solve +#pragma omp parallel for schedule(static) + for (int leaf_num = 0; leaf_num < num_leaves; ++leaf_num) { + if (total_nonzero[leaf_num] < static_cast(leaf_features[leaf_num].size()) + 1) { + if (is_refit) { + double old_const = tree->LeafConst(leaf_num); + tree->SetLeafConst(leaf_num, decay_rate * old_const + (1.0 - decay_rate) * tree->LeafOutput(leaf_num) * shrinkage); + tree->SetLeafCoeffs(leaf_num, std::vector(leaf_features[leaf_num].size(), 0)); + tree->SetLeafFeaturesInner(leaf_num, leaf_features[leaf_num]); + } else { + tree->SetLeafConst(leaf_num, tree->LeafOutput(leaf_num)); + } + continue; + } + int num_feat = leaf_features[leaf_num].size(); + Eigen::MatrixXd XTHX_mat(num_feat + 1, num_feat + 1); + Eigen::MatrixXd XTg_mat(num_feat + 1, 1); + int j = 0; + for (int feat1 = 0; feat1 < num_feat + 1; ++feat1) { + for (int feat2 = feat1; feat2 < num_feat + 1; ++feat2) { + XTHX_mat(feat1, feat2) = XTHX_[leaf_num][j]; + XTHX_mat(feat2, feat1) = XTHX_mat(feat1, feat2); + if ((feat1 == feat2) && (feat1 < num_feat)) { + XTHX_mat(feat1, feat2) += config_->linear_lambda; + } + ++j; + } + XTg_mat(feat1) = XTg_[leaf_num][feat1]; + } + Eigen::MatrixXd coeffs = - XTHX_mat.fullPivLu().inverse() * XTg_mat; + std::vector coeffs_vec; + std::vector features_new; + std::vector old_coeffs = tree->LeafCoeffs(leaf_num); + for (size_t i = 0; i < leaf_features[leaf_num].size(); ++i) { + if (is_refit) { + features_new.push_back(leaf_features[leaf_num][i]); + coeffs_vec.push_back(decay_rate * old_coeffs[i] + (1.0 - decay_rate) * coeffs(i) * shrinkage); + } else { + if (coeffs(i) < -kZeroThreshold || coeffs(i) > kZeroThreshold) { + coeffs_vec.push_back(coeffs(i)); + int feat = leaf_features[leaf_num][i]; + features_new.push_back(feat); + } + } + } + // update the tree properties + tree->SetLeafFeaturesInner(leaf_num, features_new); + std::vector features_raw(features_new.size()); + for (size_t i = 0; i < features_new.size(); ++i) { + features_raw[i] = train_data_->RealFeatureIndex(features_new[i]); + } + tree->SetLeafFeatures(leaf_num, features_raw); + tree->SetLeafCoeffs(leaf_num, coeffs_vec); + if (is_refit) { + double old_const = tree->LeafConst(leaf_num); + tree->SetLeafConst(leaf_num, decay_rate * old_const + (1.0 - decay_rate) * coeffs(num_feat) * shrinkage); + } else { + tree->SetLeafConst(leaf_num, coeffs(num_feat)); + } + } 
+} +#endif // LGB_R_BUILD +} // namespace LightGBM diff --git a/src/treelearner/linear_tree_learner.h b/src/treelearner/linear_tree_learner.h new file mode 100644 index 000000000000..88bc7f91ede1 --- /dev/null +++ b/src/treelearner/linear_tree_learner.h @@ -0,0 +1,128 @@ +/*! + * Copyright (c) 2016 Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See LICENSE file in the project root for license information. + */ +#ifndef LIGHTGBM_TREELEARNER_LINEAR_TREE_LEARNER_H_ +#define LIGHTGBM_TREELEARNER_LINEAR_TREE_LEARNER_H_ + +#include +#include +#include +#include +#include +#include + +#include "serial_tree_learner.h" + +namespace LightGBM { + +class LinearTreeLearner: public SerialTreeLearner { + public: + explicit LinearTreeLearner(const Config* config) : SerialTreeLearner(config) {} + + void Init(const Dataset* train_data, bool is_constant_hessian) override; + + void InitLinear(const Dataset* train_data, const int max_leaves) override; + + Tree* Train(const score_t* gradients, const score_t *hessians, bool is_first_tree) override; + + /*! \brief Create array mapping dataset to leaf index, used for linear trees */ + void GetLeafMap(Tree* tree) const; + + template + void CalculateLinear(Tree* tree, bool is_refit, const score_t* gradients, const score_t* hessians, bool is_first_tree) const; + + Tree* FitByExistingTree(const Tree* old_tree, const score_t* gradients, const score_t* hessians) const override; + + Tree* FitByExistingTree(const Tree* old_tree, const std::vector& leaf_pred, + const score_t* gradients, const score_t* hessians) const override; + + void AddPredictionToScore(const Tree* tree, + double* out_score) const override { + CHECK_LE(tree->num_leaves(), data_partition_->num_leaves()); + bool has_nan = false; + if (any_nan_) { + for (int i = 0; i < tree->num_leaves() - 1 ; ++i) { + // use split_feature because split_feature_inner doesn't work when refitting existing tree + if (contains_nan_[train_data_->InnerFeatureIndex(tree->split_feature(i))]) { + has_nan = true; + break; + } + } + } + if (has_nan) { + AddPredictionToScoreInner(tree, out_score); + } else { + AddPredictionToScoreInner(tree, out_score); + } + } + + template + void AddPredictionToScoreInner(const Tree* tree, double* out_score) const { + int num_leaves = tree->num_leaves(); + std::vector leaf_const(num_leaves); + std::vector> leaf_coeff(num_leaves); + std::vector> feat_ptr(num_leaves); + std::vector leaf_output(num_leaves); + std::vector leaf_num_features(num_leaves); + for (int leaf_num = 0; leaf_num < num_leaves; ++leaf_num) { + leaf_const[leaf_num] = tree->LeafConst(leaf_num); + leaf_coeff[leaf_num] = tree->LeafCoeffs(leaf_num); + leaf_output[leaf_num] = tree->LeafOutput(leaf_num); + for (int feat : tree->LeafFeaturesInner(leaf_num)) { + feat_ptr[leaf_num].push_back(train_data_->raw_index(feat)); + } + leaf_num_features[leaf_num] = feat_ptr[leaf_num].size(); + } + OMP_INIT_EX(); +#pragma omp parallel for schedule(static) if (num_data_ > 1024) + for (int i = 0; i < num_data_; ++i) { + OMP_LOOP_EX_BEGIN(); + int leaf_num = leaf_map_[i]; + if (leaf_num < 0) { + continue; + } + double output = leaf_const[leaf_num]; + int num_feat = leaf_num_features[leaf_num]; + if (HAS_NAN) { + bool nan_found = false; + for (int feat_ind = 0; feat_ind < num_feat; ++feat_ind) { + float val = feat_ptr[leaf_num][feat_ind][i]; + if (std::isnan(val)) { + nan_found = true; + break; + } + output += val * leaf_coeff[leaf_num][feat_ind]; + } + if (nan_found) { + out_score[i] += leaf_output[leaf_num]; + } else { + 
out_score[i] += output; + } + } else { + for (int feat_ind = 0; feat_ind < num_feat; ++feat_ind) { + output += feat_ptr[leaf_num][feat_ind][i] * leaf_coeff[leaf_num][feat_ind]; + } + out_score[i] += output; + } + OMP_LOOP_EX_END(); + } + OMP_THROW_EX(); + } + + protected: + /*! \brief whether numerical features contain any nan values */ + std::vector contains_nan_; + /*! whether any numerical feature contains a nan value */ + bool any_nan_; + /*! \brief map dataset to leaves */ + mutable std::vector leaf_map_; + /*! \brief temporary storage for calculating linear model coefficients */ + mutable std::vector> XTHX_; + mutable std::vector> XTg_; + mutable std::vector>> XTHX_by_thread_; + mutable std::vector>> XTg_by_thread_; +}; + +} // namespace LightGBM +#endif // LightGBM_TREELEARNER_LINEAR_TREE_LEARNER_H_ diff --git a/src/treelearner/serial_tree_learner.cpp b/src/treelearner/serial_tree_learner.cpp index cfa7d9b38b99..2fd7fcba5859 100644 --- a/src/treelearner/serial_tree_learner.cpp +++ b/src/treelearner/serial_tree_learner.cpp @@ -155,7 +155,7 @@ void SerialTreeLearner::ResetConfig(const Config* config) { constraints_.reset(LeafConstraintsBase::Create(config_, config_->num_leaves, train_data_->num_features())); } -Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians) { +Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians, bool /*is_first_tree*/) { Common::FunctionTimer fun_timer("SerialTreeLearner::Train", global_timer); gradients_ = gradients; hessians_ = hessians; @@ -172,7 +172,7 @@ Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians BeforeTrain(); bool track_branch_features = !(config_->interaction_constraints_vector.empty()); - auto tree = std::unique_ptr(new Tree(config_->num_leaves, track_branch_features)); + auto tree = std::unique_ptr(new Tree(config_->num_leaves, track_branch_features, false)); auto tree_ptr = tree.get(); constraints_->ShareTreePointer(tree_ptr); @@ -203,6 +203,7 @@ Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians Split(tree_ptr, best_leaf, &left_leaf, &right_leaf); cur_depth = std::max(cur_depth, tree->leaf_depth(left_leaf)); } + Log::Debug("Trained a tree with leaves = %d and max_depth = %d", tree->num_leaves(), cur_depth); return tree.release(); } @@ -242,7 +243,8 @@ Tree* SerialTreeLearner::FitByExistingTree(const Tree* old_tree, const score_t* return tree.release(); } -Tree* SerialTreeLearner::FitByExistingTree(const Tree* old_tree, const std::vector& leaf_pred, const score_t* gradients, const score_t *hessians) { +Tree* SerialTreeLearner::FitByExistingTree(const Tree* old_tree, const std::vector& leaf_pred, + const score_t* gradients, const score_t *hessians) const { data_partition_->ResetByLeafPred(leaf_pred, old_tree->num_leaves()); return FitByExistingTree(old_tree, gradients, hessians); } diff --git a/src/treelearner/serial_tree_learner.h b/src/treelearner/serial_tree_learner.h index fd06c235a8bf..1a0bda53b5db 100644 --- a/src/treelearner/serial_tree_learner.h +++ b/src/treelearner/serial_tree_learner.h @@ -39,6 +39,7 @@ using json11::Json; /*! \brief forward declaration */ class CostEfficientGradientBoosting; + /*! 
* \brief Used for learning a tree by single machine */ @@ -74,12 +75,12 @@ class SerialTreeLearner: public TreeLearner { } } - Tree* Train(const score_t* gradients, const score_t *hessians) override; + Tree* Train(const score_t* gradients, const score_t *hessians, bool is_first_tree) override; Tree* FitByExistingTree(const Tree* old_tree, const score_t* gradients, const score_t* hessians) const override; Tree* FitByExistingTree(const Tree* old_tree, const std::vector& leaf_pred, - const score_t* gradients, const score_t* hessians) override; + const score_t* gradients, const score_t* hessians) const override; void SetBaggingData(const Dataset* subset, const data_size_t* used_indices, data_size_t num_data) override { if (subset == nullptr) { @@ -96,10 +97,10 @@ class SerialTreeLearner: public TreeLearner { void AddPredictionToScore(const Tree* tree, double* out_score) const override { + CHECK_LE(tree->num_leaves(), data_partition_->num_leaves()); if (tree->num_leaves() <= 1) { return; } - CHECK_LE(tree->num_leaves(), data_partition_->num_leaves()); #pragma omp parallel for schedule(static, 1) for (int i = 0; i < tree->num_leaves(); ++i) { double output = static_cast(tree->LeafOutput(i)); @@ -189,7 +190,6 @@ class SerialTreeLearner: public TreeLearner { FeatureHistogram* smaller_leaf_histogram_array_; /*! \brief pointer to histograms array of larger leaf */ FeatureHistogram* larger_leaf_histogram_array_; - /*! \brief store best split points for all leaves */ std::vector best_split_per_leaf_; /*! \brief store best split per feature for all leaves */ diff --git a/src/treelearner/tree_learner.cpp b/src/treelearner/tree_learner.cpp index ab009a0b1008..04184932a335 100644 --- a/src/treelearner/tree_learner.cpp +++ b/src/treelearner/tree_learner.cpp @@ -8,13 +8,22 @@ #include "gpu_tree_learner.h" #include "parallel_tree_learner.h" #include "serial_tree_learner.h" +#include "linear_tree_learner.h" namespace LightGBM { -TreeLearner* TreeLearner::CreateTreeLearner(const std::string& learner_type, const std::string& device_type, const Config* config) { +TreeLearner* TreeLearner::CreateTreeLearner(const std::string& learner_type, const std::string& device_type, + const Config* config) { if (device_type == std::string("cpu")) { if (learner_type == std::string("serial")) { - return new SerialTreeLearner(config); + if (config->linear_tree) { +#ifdef LGB_R_BUILD + Log::Fatal("Linear tree learner does not work with R package."); +#endif // LGB_R_BUILD + return new LinearTreeLearner(config); + } else { + return new SerialTreeLearner(config); + } } else if (learner_type == std::string("feature")) { return new FeatureParallelTreeLearner(config); } else if (learner_type == std::string("data")) { diff --git a/tests/python_package_test/test_basic.py b/tests/python_package_test/test_basic.py index b499ca7b1ea2..ed410e90088b 100644 --- a/tests/python_package_test/test_basic.py +++ b/tests/python_package_test/test_basic.py @@ -99,6 +99,38 @@ def test_chunked_dataset(self): train_data.construct() valid_data.construct() + def test_chunked_dataset_linear(self): + X_train, X_test, y_train, y_test = train_test_split(*load_breast_cancer(return_X_y=True), test_size=0.1, + random_state=2) + chunk_size = X_train.shape[0] // 10 + 1 + X_train = [X_train[i * chunk_size:(i + 1) * chunk_size, :] for i in range(X_train.shape[0] // chunk_size + 1)] + X_test = [X_test[i * chunk_size:(i + 1) * chunk_size, :] for i in range(X_test.shape[0] // chunk_size + 1)] + params = {"bin_construct_sample_cnt": 100, 'linear_tree': True} + 
train_data = lgb.Dataset(X_train, label=y_train, params=params) + valid_data = train_data.create_valid(X_test, label=y_test, params=params) + train_data.construct() + valid_data.construct() + + def test_save_and_load_linear(self): + X_train, X_test, y_train, y_test = train_test_split(*load_breast_cancer(return_X_y=True), test_size=0.1, + random_state=2) + X_train = np.concatenate([np.ones((X_train.shape[0], 1)), X_train], 1) + X_train[:X_train.shape[0] // 2, 0] = 0 + y_train[:X_train.shape[0] // 2] = 1 + params = {'linear_tree': True} + train_data = lgb.Dataset(X_train, label=y_train, params=params) + est = lgb.train(params, train_data, num_boost_round=10, categorical_feature=[0]) + pred1 = est.predict(X_train) + train_data.save_binary('temp_dataset.bin') + train_data_2 = lgb.Dataset('temp_dataset.bin') + est = lgb.train(params, train_data_2, num_boost_round=10) + pred2 = est.predict(X_train) + np.testing.assert_allclose(pred1, pred2) + est.save_model('temp_model.txt') + est2 = lgb.Booster(model_file='temp_model.txt') + pred3 = est2.predict(X_train) + np.testing.assert_allclose(pred2, pred3) + def test_subset_group(self): X_train, y_train = load_svmlight_file(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../examples/lambdarank/rank.train')) diff --git a/tests/python_package_test/test_consistency.py b/tests/python_package_test/test_consistency.py index 782dac4368e3..ce6223d13935 100644 --- a/tests/python_package_test/test_consistency.py +++ b/tests/python_package_test/test_consistency.py @@ -78,6 +78,18 @@ def test_binary(self): fd.train_predict_check(lgb_train, X_test, X_test_fn, sk_pred) fd.file_load_check(lgb_train, '.train') + def test_binary_linear(self): + fd = FileLoader('../../examples/binary_classification', 'binary', 'train_linear.conf') + X_train, y_train, _ = fd.load_dataset('.train') + X_test, _, X_test_fn = fd.load_dataset('.test') + weight_train = fd.load_field('.train.weight') + lgb_train = lgb.Dataset(X_train, y_train, params=fd.params, weight=weight_train) + gbm = lgb.LGBMClassifier(**fd.params) + gbm.fit(X_train, y_train, sample_weight=weight_train) + sk_pred = gbm.predict_proba(X_test)[:, 1] + fd.train_predict_check(lgb_train, X_test, X_test_fn, sk_pred) + fd.file_load_check(lgb_train, '.train') + def test_multiclass(self): fd = FileLoader('../../examples/multiclass_classification', 'multiclass') X_train, y_train, _ = fd.load_dataset('.train') diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py index 84f5a8cc1071..faed1fce28b4 100644 --- a/tests/python_package_test/test_engine.py +++ b/tests/python_package_test/test_engine.py @@ -2420,6 +2420,86 @@ def test_interaction_constraints(self): [1] + list(range(2, num_features))]), train_data, num_boost_round=10) + def test_linear(self): + # check that setting boosting=gbdt_linear fits better than boosting=gbdt when data has linear relationship + np.random.seed(0) + x = np.arange(0, 100, 0.1) + y = 2 * x + np.random.normal(0, 0.1, len(x)) + lgb_train = lgb.Dataset(x[:, np.newaxis], label=y) + params = {'verbose': -1, + 'metric': 'mse', + 'seed': 0, + 'num_leaves': 2} + est = lgb.train(params, lgb_train, num_boost_round=10) + pred1 = est.predict(x[:, np.newaxis]) + lgb_train = lgb.Dataset(x[:, np.newaxis], label=y) + res = {} + est = lgb.train(dict(params, linear_tree=True), lgb_train, num_boost_round=10, evals_result=res, + valid_sets=[lgb_train], valid_names=['train']) + pred2 = est.predict(x[:, np.newaxis]) + np.testing.assert_allclose(res['train']['l2'][-1], 
mean_squared_error(y, pred2), atol=10**(-1)) + self.assertLess(mean_squared_error(y, pred2), mean_squared_error(y, pred1)) + # test again with nans in data + x[:10] = np.nan + lgb_train = lgb.Dataset(x[:, np.newaxis], label=y) + est = lgb.train(params, lgb_train, num_boost_round=10) + pred1 = est.predict(x[:, np.newaxis]) + lgb_train = lgb.Dataset(x[:, np.newaxis], label=y) + res = {} + est = lgb.train(dict(params, linear_tree=True), lgb_train, num_boost_round=10, evals_result=res, + valid_sets=[lgb_train], valid_names=['train']) + pred2 = est.predict(x[:, np.newaxis]) + np.testing.assert_allclose(res['train']['l2'][-1], mean_squared_error(y, pred2), atol=10**(-1)) + self.assertLess(mean_squared_error(y, pred2), mean_squared_error(y, pred1)) + # test again with bagging + res = {} + est = lgb.train(dict(params, linear_tree=True, subsample=0.8, bagging_freq=1), lgb_train, + num_boost_round=10, evals_result=res, valid_sets=[lgb_train], valid_names=['train']) + pred = est.predict(x[:, np.newaxis]) + np.testing.assert_allclose(res['train']['l2'][-1], mean_squared_error(y, pred), atol=10**(-1)) + # test with a feature that has only one non-nan value + x = np.concatenate([np.ones([x.shape[0], 1]), x[:, np.newaxis]], 1) + x[500:, 1] = np.nan + y[500:] += 10 + lgb_train = lgb.Dataset(x, label=y) + res = {} + est = lgb.train(dict(params, linear_tree=True, subsample=0.8, bagging_freq=1), lgb_train, + num_boost_round=10, evals_result=res, valid_sets=[lgb_train], valid_names=['train']) + pred = est.predict(x) + np.testing.assert_allclose(res['train']['l2'][-1], mean_squared_error(y, pred), atol=10**(-1)) + # test with a categorical feature + x[:250, 0] = 0 + y[:250] += 10 + lgb_train = lgb.Dataset(x, label=y) + est = lgb.train(dict(params, linear_tree=True, subsample=0.8, bagging_freq=1), lgb_train, + num_boost_round=10, categorical_feature=[0]) + # test refit: same results on same data + est2 = est.refit(x, label=y) + p1 = est.predict(x) + p2 = est2.predict(x) + self.assertLess(np.mean(np.abs(p1 - p2)), 2) + # test refit with save and load + est.save_model('temp_model.txt') + est2 = lgb.Booster(model_file='temp_model.txt') + est2 = est2.refit(x, label=y) + p1 = est.predict(x) + p2 = est2.predict(x) + self.assertLess(np.mean(np.abs(p1 - p2)), 2) + # test refit: different results training on different data + est2 = est.refit(x[:100, :], label=y[:100]) + p3 = est2.predict(x) + self.assertGreater(np.mean(np.abs(p2 - p1)), np.abs(np.max(p3 - p1))) + # test when num_leaves - 1 < num_features and when num_leaves - 1 > num_features + X_train, X_test, y_train, y_test = train_test_split(*load_breast_cancer(return_X_y=True), test_size=0.1, random_state=2) + params = {'linear_tree': True, + 'verbose': -1, + 'metric': 'mse', + 'seed': 0} + train_data = lgb.Dataset(X_train, label=y_train, params=dict(params, num_leaves=2)) + est = lgb.train(params, train_data, num_boost_round=10, categorical_feature=[0]) + train_data = lgb.Dataset(X_train, label=y_train, params=dict(params, num_leaves=60)) + est = lgb.train(params, train_data, num_boost_round=10, categorical_feature=[0]) + def test_predict_with_start_iteration(self): def inner_test(X, y, params, early_stopping_rounds): X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42) diff --git a/windows/LightGBM.vcxproj b/windows/LightGBM.vcxproj index 6b37ca53b728..5c01a8cdb932 100644 --- a/windows/LightGBM.vcxproj +++ b/windows/LightGBM.vcxproj @@ -112,6 +112,7 @@ Disabled MultiThreadedDebugDLL true + 
<AdditionalIncludeDirectories>$(ProjectDir)\..\eigen;%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories>
@@ -134,6 +135,7 @@
     Disabled MultiThreadedDebugDLL true
+      <AdditionalIncludeDirectories>$(ProjectDir)\..\eigen;%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories>
@@ -153,6 +155,7 @@
     Disabled MultiThreadedDebugDLL true
+      <AdditionalIncludeDirectories>$(ProjectDir)\..\eigen;%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories>
@@ -175,6 +178,7 @@
     true MultiThreadedDLL true
+      <AdditionalIncludeDirectories>$(ProjectDir)\..\eigen;%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories>
@@ -201,6 +205,7 @@
     true true true
+      <AdditionalIncludeDirectories>$(ProjectDir)\..\eigen;%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories>
@@ -221,6 +226,7 @@
     true true true
+      <AdditionalIncludeDirectories>$(ProjectDir)\..\eigen;%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories>
@@ -287,6 +293,7 @@
+    <ClInclude Include="..\src\treelearner\linear_tree_learner.h" />
@@ -321,6 +328,7 @@
+    <ClCompile Include="..\src\treelearner\linear_tree_learner.cpp" />
@@ -328,4 +336,4 @@
-</Project>
+</Project>
\ No newline at end of file
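
Note on the leaf coefficient solve (CalculateLinear in src/treelearner/linear_tree_learner.cpp): each leaf's linear model comes from the accumulated per-leaf matrices as coefficients = -(X^T * H * X + lambda)^(-1) * (X^T * g), where X holds the leaf's raw numeric feature values plus a constant column, H is the diagonal hessian, g the gradient vector, and linear_lambda is added only to the feature (not the constant) diagonal entries. A minimal numpy sketch of that math, with illustrative names that are not part of the PR's API:

import numpy as np

def fit_leaf_linear_model(X, g, h, linear_lambda=0.0):
    # X: (n, k) raw numeric feature values for the rows falling in one leaf
    # g, h: per-row gradients and hessians; returns (k coefficients, constant term)
    n, k = X.shape
    Xc = np.hstack([X, np.ones((n, 1))])                   # append the constant column
    XTHX = Xc.T @ (h[:, None] * Xc)                        # X^T * H * X with diagonal H
    XTg = Xc.T @ g                                         # X^T * g
    reg = np.diag(np.r_[np.full(k, linear_lambda), 0.0])   # no penalty on the constant term
    coeffs = -np.linalg.solve(XTHX + reg, XTg)
    return coeffs[:k], coeffs[k]

The C++ code stores only the upper triangle of X^T * H * X per leaf and per thread and aggregates before solving with Eigen; the dense solve above is just the unoptimized equivalent.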
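Note on prediction with linear leaves (PredictionFunLinear in src/io/tree.cpp and AddPredictionToScoreInner in linear_tree_learner.h): a row's contribution at a leaf is the leaf constant plus the dot product of the leaf coefficients with the row's raw feature values, and if any of those features is NaN the code falls back to the ordinary constant leaf output. A small per-row sketch under the same assumptions (argument names are illustrative):

import numpy as np

def linear_leaf_score(row, leaf_const, coeffs, feature_idx, fallback_value):
    # row: raw feature values for one data point
    # coeffs / feature_idx: the leaf's linear coefficients and the features they apply to
    score = leaf_const
    for coef, j in zip(coeffs, feature_idx):
        val = row[j]
        if np.isnan(val):
            return fallback_value   # NaN in a used feature -> constant leaf output
        score += coef * val
    return score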