diff --git a/.travis.yml b/.travis.yml
index acbdac7ec825..d339fefe95cd 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -23,6 +23,7 @@ env:
     - TASK=if-else
     - TASK=sdist PYTHON_VERSION=3.4
     - TASK=bdist PYTHON_VERSION=3.5
+    - TASK=proto
     - TASK=gpu METHOD=source
     - TASK=gpu METHOD=pip
@@ -38,6 +39,8 @@ matrix:
       env: TASK=pylint
     - os: osx
       env: TASK=check-docs
+    - os: osx
+      env: TASK=proto
 
 before_install:
   - test -n $CC && unset CC
diff --git a/.travis/test.sh b/.travis/test.sh
index dc4fc2e35234..d52795d88c13 100644
--- a/.travis/test.sh
+++ b/.travis/test.sh
@@ -50,12 +50,24 @@ if [[ ${TASK} == "if-else" ]]; then
     conda create -q -n test-env python=$PYTHON_VERSION numpy
     source activate test-env
     mkdir build && cd build && cmake .. && make lightgbm || exit -1
-    cd $TRAVIS_BUILD_DIR/tests/cpp_test && ../../lightgbm config=train.conf && ../../lightgbm config=predict.conf output_result=origin.pred || exit -1
+    cd $TRAVIS_BUILD_DIR/tests/cpp_test && ../../lightgbm config=train.conf convert_model_language=cpp convert_model=../../src/boosting/gbdt_prediction.cpp && ../../lightgbm config=predict.conf output_result=origin.pred || exit -1
     cd $TRAVIS_BUILD_DIR/build && make lightgbm || exit -1
     cd $TRAVIS_BUILD_DIR/tests/cpp_test && ../../lightgbm config=predict.conf output_result=ifelse.pred && python test.py || exit -1
     exit 0
 fi
 
+if [[ ${TASK} == "proto" ]]; then
+    conda create -q -n test-env python=$PYTHON_VERSION numpy
+    source activate test-env
+    mkdir build && cd build && cmake .. && make lightgbm || exit -1
+    cd $TRAVIS_BUILD_DIR/tests/cpp_test && ../../lightgbm config=train.conf && ../../lightgbm config=predict.conf output_result=origin.pred || exit -1
+    cd $TRAVIS_BUILD_DIR && git clone https://github.com/google/protobuf && cd protobuf && ./autogen.sh && ./configure && make && sudo make install && sudo ldconfig
+    cd $TRAVIS_BUILD_DIR/build && rm -rf * && cmake -DUSE_PROTO=ON .. && make lightgbm || exit -1
+    cd $TRAVIS_BUILD_DIR/tests/cpp_test && ../../lightgbm config=train.conf model_format=proto && ../../lightgbm config=predict.conf output_result=proto.pred model_format=proto || exit -1
+    cd $TRAVIS_BUILD_DIR/tests/cpp_test && python test.py || exit -1
+    exit 0
+fi
+
 conda create -q -n test-env python=$PYTHON_VERSION numpy nose scipy scikit-learn pandas matplotlib pytest
 source activate test-env
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 30fcf4299307..a5db5d9e923e 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -124,8 +124,25 @@ file(GLOB SOURCES
     src/treelearner/*.cpp
 )
 
-add_executable(lightgbm src/main.cpp ${SOURCES})
-add_library(_lightgbm SHARED src/c_api.cpp src/lightgbm_R.cpp ${SOURCES})
+if (USE_PROTO)
+    if(MSVC)
+        message(FATAL_ERROR "Cannot use proto with MSVC.")
+    endif(MSVC)
+    find_package(Protobuf REQUIRED)
+    PROTOBUF_GENERATE_CPP(PROTO_SRCS PROTO_HDRS proto/model.proto)
+    include_directories(${PROTOBUF_INCLUDE_DIRS})
+    include_directories(${CMAKE_CURRENT_BINARY_DIR})
+    ADD_DEFINITIONS(-DUSE_PROTO)
+    SET(PROTO_FILES src/proto/gbdt_model_proto.cpp ${PROTO_HDRS} ${PROTO_SRCS})
+endif(USE_PROTO)
+
+add_executable(lightgbm src/main.cpp ${SOURCES} ${PROTO_FILES})
+add_library(_lightgbm SHARED src/c_api.cpp src/lightgbm_R.cpp ${SOURCES} ${PROTO_FILES})
+
+if (USE_PROTO)
+    TARGET_LINK_LIBRARIES(lightgbm ${PROTOBUF_LIBRARIES})
+    TARGET_LINK_LIBRARIES(_lightgbm ${PROTOBUF_LIBRARIES})
+endif(USE_PROTO)
 
 if(MSVC)
   set_target_properties(_lightgbm PROPERTIES OUTPUT_NAME "lib_lightgbm")
diff --git a/docs/Installation-Guide.rst b/docs/Installation-Guide.rst
index 67ab9702dc99..de081182cd70 100644
--- a/docs/Installation-Guide.rst
+++ b/docs/Installation-Guide.rst
@@ -271,6 +271,21 @@ Following procedure is for the MSVC (Microsoft Visual C++) build.
 
 **Note**: ``C:\local\boost_1_64_0\`` and ``C:\local\boost_1_64_0\lib64-msvc-14.0`` are locations of your Boost binaries. You also can set them to the environment variable to avoid ``Set ...`` commands when build.
 
+Protobuf Support
+^^^^^^^^^^^^^^^^
+
+If you want to use protobuf to save and load models, install the `protobuf C++ version `__ first.
+
+Then run cmake with ``USE_PROTO`` switched on, for example:
+
+.. code::
+
+    cmake -DUSE_PROTO=ON ..
+
+You can then use ``model_format=proto`` in parameters when saving and loading models.
+
+**Note**: for Windows users, this is only tested with MinGW.
+
 Docker
 ^^^^^^
diff --git a/docs/Parameters.rst b/docs/Parameters.rst
index 36ba1cd3d85e..66c046fb5723 100644
--- a/docs/Parameters.rst
+++ b/docs/Parameters.rst
@@ -309,6 +309,20 @@ IO Parameters
 
   - file name of prediction result in ``prediction`` task
 
+- ``model_format``, default=\ ``text``, type=string
+
+  - format used to save and load models.
+
+  - ``text``, use text string.
+
+  - ``proto``, use protocol buffer binary format.
+
+  - to save in multiple formats, join them with commas, like ``text,proto``; in this case, ``model_format`` will be added as a suffix after ``output_model``.
+
+  - loading models saved in multiple formats is not supported.
+
+  - **Note**: you need to run cmake with ``-DUSE_PROTO=ON`` to use this parameter.
+
 - ``is_pre_partition``, default=\ ``false``, type=bool
 
   - used for parallel learning (not include feature parallel)
diff --git a/include/LightGBM/boosting.h b/include/LightGBM/boosting.h
index 28185729a419..fca7f17b497d 100644
--- a/include/LightGBM/boosting.h
+++ b/include/LightGBM/boosting.h
@@ -4,6 +4,10 @@
 #include
 #include
 
+#ifdef USE_PROTO
+#include "model.pb.h"
+#endif // USE_PROTO
+
 #include
 #include
@@ -166,7 +170,7 @@ class LIGHTGBM_EXPORT Boosting {
 
   /*!
   * \brief Save model to file
-  * \param num_used_model Number of model that want to save, -1 means save all
+  * \param num_iterations Number of models to save, -1 means save all
   * \param is_finish Is training finished or not
   * \param filename Filename that want to save to
   * \return true if succeeded
@@ -175,7 +179,7 @@ class LIGHTGBM_EXPORT Boosting {
 
   /*!
   * \brief Save model to string
-  * \param num_used_model Number of model that want to save, -1 means save all
+  * \param num_iterations Number of models to save, -1 means save all
   * \return Non-empty string if succeeded
   */
   virtual std::string SaveModelToString(int num_iterations) const = 0;
@@ -187,6 +191,22 @@ class LIGHTGBM_EXPORT Boosting {
   */
   virtual bool LoadModelFromString(const std::string& model_str) = 0;
 
+  #ifdef USE_PROTO
+  /*!
+  * \brief Save model with protobuf
+  * \param num_iteration Number of models to save, -1 means save all
+  * \param filename Filename that want to save to
+  */
+  virtual void SaveModelToProto(int num_iteration, const char* filename) const = 0;
+
+  /*!
+  * \brief Restore from a serialized protobuf file
+  * \param filename Filename that want to restore from
+  * \return true if succeeded
+  */
+  virtual bool LoadModelFromProto(const char* filename) = 0;
+  #endif // USE_PROTO
+
   /*!
   * \brief Calculate feature importances
   * \param num_iteration Number of model that want to use for feature importance, -1 means use all
@@ -251,23 +271,17 @@ class LIGHTGBM_EXPORT Boosting {
   /*! \brief Disable copy */
   Boosting(const Boosting&) = delete;
 
-  static bool LoadFileToBoosting(Boosting* boosting, const char* filename);
+  static bool LoadFileToBoosting(Boosting* boosting, const std::string& format, const char* filename);
 
   /*!
   * \brief Create boosting object
   * \param type Type of boosting
+  * \param format Format of model
   * \param config config for boosting
   * \param filename name of model file, if existing will continue to train from this model
   * \return The boosting object
   */
-  static Boosting* CreateBoosting(const std::string& type, const char* filename);
-
-  /*!
-  * \brief Create boosting object from model file
-  * \param filename name of model file
-  * \return The boosting object
-  */
-  static Boosting* CreateBoosting(const char* filename);
+  static Boosting* CreateBoosting(const std::string& type, const std::string& format, const char* filename);
 };
diff --git a/include/LightGBM/config.h b/include/LightGBM/config.h
index 4a8dfccb788b..8fc7091aaaa1 100644
--- a/include/LightGBM/config.h
+++ b/include/LightGBM/config.h
@@ -105,6 +105,7 @@ struct IOConfig: public ConfigBase {
   std::string output_result = "LightGBM_predict_result.txt";
   std::string convert_model = "gbdt_prediction.cpp";
   std::string input_model = "";
+  std::string model_format = "text";
   int verbosity = 1;
   int num_iteration_predict = -1;
   bool is_pre_partition = false;
@@ -445,7 +446,7 @@ struct ParameterAlias {
     const std::unordered_set<std::string> parameter_set({
       "config", "config_file", "task", "device", "num_threads", "seed",
       "boosting_type", "objective", "data",
-      "output_model", "input_model", "output_result", "valid_data",
+      "output_model", "input_model", "output_result", "model_format", "valid_data",
       "is_enable_sparse", "is_pre_partition", "is_training_metric",
       "ndcg_eval_at", "min_data_in_leaf", "min_sum_hessian_in_leaf",
       "num_leaves", "feature_fraction", "num_iterations",
diff --git a/include/LightGBM/tree.h b/include/LightGBM/tree.h
index d40f19ce2930..471b8370b428 100644
--- a/include/LightGBM/tree.h
+++ b/include/LightGBM/tree.h
@@ -3,6 +3,9 @@
 #include
 #include
+#ifdef USE_PROTO
+#include "model.pb.h"
+#endif // USE_PROTO
 
 #include
 #include
@@ -30,6 +33,13 @@ class Tree {
   * \param str Model string
   */
   explicit Tree(const std::string& str);
+  #ifdef USE_PROTO
+  /*!
+  * \brief Constructor, from a protobuf object
+  * \param model_tree Model protobuf object
+  */
+  explicit Tree(const Model_Tree& model_tree);
+  #endif // USE_PROTO
 
   ~Tree();
@@ -165,6 +175,11 @@ class Tree {
   /*! \brief Serialize this object to if-else statement*/
   std::string ToIfElse(int index, bool is_predict_leaf_index) const;
 
+  #ifdef USE_PROTO
+  /*! \brief Serialize this object to protobuf object*/
+  void ToProto(Model_Tree& model_tree) const;
+  #endif // USE_PROTO
+
   inline static bool IsZero(double fval) {
     if (fval > -kZeroAsMissingValueRange && fval <= kZeroAsMissingValueRange) {
       return true;
diff --git a/proto/model.proto b/proto/model.proto
new file mode 100644
index 000000000000..7d62c1811f36
--- /dev/null
+++ b/proto/model.proto
@@ -0,0 +1,33 @@
+syntax = "proto3";
+
+package LightGBM;
+
+message Model {
+  string name = 1;
+  uint32 num_class = 2;
+  uint32 num_tree_per_iteration = 3;
+  uint32 label_index = 4;
+  uint32 max_feature_idx = 5;
+  string objective = 6;
+  bool average_output = 7;
+  repeated string feature_names = 8;
+  repeated string feature_infos = 9;
+  message Tree {
+    uint32 num_leaves = 1;
+    uint32 num_cat = 2;
+    repeated uint32 split_feature = 3;
+    repeated double split_gain = 4;
+    repeated double threshold = 5;
+    repeated uint32 decision_type = 6;
+    repeated sint32 left_child = 7;
+    repeated sint32 right_child = 8;
+    repeated double leaf_value = 9;
+    repeated uint32 leaf_count = 10;
+    repeated double internal_value = 11;
+    repeated double internal_count = 12;
+    repeated sint32 cat_boundaries = 13;
+    repeated uint32 cat_threshold = 14;
+    double shrinkage = 15;
+  }
+  repeated Tree trees = 10;
+}
diff --git a/src/application/application.cpp b/src/application/application.cpp
index 5458e2babf3a..c0637dea9b13 100644
--- a/src/application/application.cpp
+++ b/src/application/application.cpp
@@ -180,6 +180,7 @@ void Application::InitTrain() {
   // create boosting
   boosting_.reset(
     Boosting::CreateBoosting(config_.boosting_type,
+                             config_.io_config.model_format.c_str(),
                              config_.io_config.input_model.c_str()));
   // create objective function
   objective_fun_.reset(
@@ -203,6 +204,26 @@ void Application::InitTrain() {
 void Application::Train() {
   Log::Info("Started training...");
   boosting_->Train(config_.io_config.snapshot_freq, config_.io_config.output_model);
+  std::vector<std::string> model_formats = Common::Split(config_.io_config.model_format.c_str(), ',');
+  bool save_with_multiple_format = (model_formats.size() > 1);
+  for (auto model_format: model_formats) {
+    std::string save_file_name = config_.io_config.output_model;
+    if (save_with_multiple_format) {
+      // use a suffix to distinguish the different model formats
+      save_file_name += "." + model_format;
+    }
+    if (model_format == std::string("text")) {
+      boosting_->SaveModelToFile(-1, save_file_name.c_str());
+    } else if (model_format == std::string("proto")) {
+      #ifdef USE_PROTO
+      boosting_->SaveModelToProto(-1, save_file_name.c_str());
+      #else
+      Log::Fatal("Please cmake with -DUSE_PROTO=ON to use protobuf.");
+      #endif // USE_PROTO
+    } else {
+      Log::Fatal("Unknown model format during saving: %s", model_format.c_str());
+    }
+  }
   // convert model to if-else statement code
   if (config_.convert_model_language == std::string("cpp")) {
     boosting_->SaveModelToIfElse(-1, config_.io_config.convert_model.c_str());
@@ -223,13 +244,15 @@ void Application::Predict() {
 
 void Application::InitPredict() {
   boosting_.reset(
-    Boosting::CreateBoosting(config_.io_config.input_model.c_str()));
+    Boosting::CreateBoosting("gbdt", config_.io_config.model_format.c_str(),
+                             config_.io_config.input_model.c_str()));
   Log::Info("Finished initializing prediction");
 }
 
 void Application::ConvertModel() {
   boosting_.reset(
     Boosting::CreateBoosting(config_.boosting_type,
+                             config_.io_config.model_format.c_str(),
                              config_.io_config.input_model.c_str()));
   boosting_->SaveModelToIfElse(-1, config_.io_config.convert_model.c_str());
 }
diff --git a/src/boosting/boosting.cpp b/src/boosting/boosting.cpp
index 82cfd62387aa..a21e8b28944a 100644
--- a/src/boosting/boosting.cpp
+++ b/src/boosting/boosting.cpp
@@ -12,21 +12,34 @@ std::string GetBoostingTypeFromModelFile(const char* filename) {
   return type;
 }
 
-bool Boosting::LoadFileToBoosting(Boosting* boosting, const char* filename) {
+bool Boosting::LoadFileToBoosting(Boosting* boosting, const std::string& format, const char* filename) {
   if (boosting != nullptr) {
-    TextReader model_reader(filename, true);
-    model_reader.ReadAllLines();
-    std::stringstream str_buf;
-    for (auto& line : model_reader.Lines()) {
-      str_buf << line << '\n';
+    if (format == std::string("text")) {
+      TextReader model_reader(filename, true);
+      model_reader.ReadAllLines();
+      std::stringstream str_buf;
+      for (auto& line : model_reader.Lines()) {
+        str_buf << line << '\n';
+      }
+      if (!boosting->LoadModelFromString(str_buf.str())) {
+        return false;
+      }
+    } else if (format == std::string("proto")) {
+      #ifdef USE_PROTO
+      if (!boosting->LoadModelFromProto(filename)) {
+        return false;
+      }
+      #else
+      Log::Fatal("Please cmake with -DUSE_PROTO=ON to use protobuf.");
+      #endif // USE_PROTO
+    } else {
+      Log::Fatal("Unknown model format during loading: %s", format.c_str());
     }
-    if (!boosting->LoadModelFromString(str_buf.str()))
-      return false;
   }
   return true;
 }
 
-Boosting* Boosting::CreateBoosting(const std::string& type, const char* filename) {
+Boosting* Boosting::CreateBoosting(const std::string& type, const std::string& format, const char* filename) {
   if (filename == nullptr || filename[0] == '\0') {
     if (type == std::string("gbdt")) {
       return new GBDT();
@@ -41,8 +54,7 @@ Boosting* Boosting::CreateBoosting(const std::string& type, const char* filename
     }
   } else {
     std::unique_ptr<Boosting> ret;
-    auto type_in_file = GetBoostingTypeFromModelFile(filename);
-    if (type_in_file == std::string("tree")) {
+    if (format == std::string("proto") || GetBoostingTypeFromModelFile(filename) == std::string("tree")) {
       if (type == std::string("gbdt")) {
         ret.reset(new GBDT());
       } else if (type == std::string("dart")) {
@@ -54,24 +66,12 @@ Boosting* Boosting::CreateBoosting(const std::string& type, const char* filename
     } else {
       Log::Fatal("unknown boosting type %s", type.c_str());
     }
-    LoadFileToBoosting(ret.get(), filename);
+    LoadFileToBoosting(ret.get(), format, filename);
     } else {
-      Log::Fatal("unknown submodel type in model file %s", filename);
+      Log::Fatal("unknown model format or submodel type in model file %s", filename);
     }
     return ret.release();
   }
 }
 
-Boosting* Boosting::CreateBoosting(const char* filename) {
-  auto type = GetBoostingTypeFromModelFile(filename);
-  std::unique_ptr<Boosting> ret;
-  if (type == std::string("tree")) {
-    ret.reset(new GBDT());
-  } else {
-    Log::Fatal("unknown submodel type in model file %s", filename);
-  }
-  LoadFileToBoosting(ret.get(), filename);
-  return ret.release();
-}
-
 }  // namespace LightGBM
diff --git a/src/boosting/gbdt.cpp b/src/boosting/gbdt.cpp
index 16acdd669275..9d158cfbd9d4 100644
--- a/src/boosting/gbdt.cpp
+++ b/src/boosting/gbdt.cpp
@@ -352,7 +352,6 @@ void GBDT::Train(int snapshot_freq, const std::string& model_output_path) {
       SaveModelToFile(-1, snapshot_out.c_str());
     }
   }
-  SaveModelToFile(-1, model_output_path.c_str());
 }
 
 double GBDT::BoostFromAverage() {
diff --git a/src/boosting/gbdt.h b/src/boosting/gbdt.h
index da5f6ee944f8..5b4893e594a5 100644
--- a/src/boosting/gbdt.h
+++ b/src/boosting/gbdt.h
@@ -236,6 +236,22 @@ class GBDT: public GBDTBase {
   */
   bool LoadModelFromString(const std::string& model_str) override;
 
+  #ifdef USE_PROTO
+  /*!
+  * \brief Save model with protobuf
+  * \param num_iteration Number of models to save, -1 means save all
+  * \param filename Filename that want to save to
+  */
+  void SaveModelToProto(int num_iteration, const char* filename) const override;
+
+  /*!
+  * \brief Restore from a serialized protobuf file
+  * \param filename Filename that want to restore from
+  * \return true if succeeded
+  */
+  bool LoadModelFromProto(const char* filename) override;
+  #endif // USE_PROTO
+
   /*!
   * \brief Calculate feature importances
   * \param num_iteration Number of model that want to use for feature importance, -1 means use all
diff --git a/src/boosting/gbdt_model.cpp b/src/boosting/gbdt_model_text.cpp
similarity index 100%
rename from src/boosting/gbdt_model.cpp
rename to src/boosting/gbdt_model_text.cpp
diff --git a/src/c_api.cpp b/src/c_api.cpp
index cd4c893bb071..3b2fdf021780 100644
--- a/src/c_api.cpp
+++ b/src/c_api.cpp
@@ -29,11 +29,7 @@ namespace LightGBM {
 class Booster {
 public:
   explicit Booster(const char* filename) {
-    boosting_.reset(Boosting::CreateBoosting(filename));
-  }
-
-  Booster() {
-    boosting_.reset(Boosting::CreateBoosting("gbdt", nullptr));
+    boosting_.reset(Boosting::CreateBoosting("gbdt", "text", filename));
   }
 
   Booster(const Dataset* train_data,
@@ -50,7 +46,7 @@
                       please use continued train with input score");
     }
 
-    boosting_.reset(Boosting::CreateBoosting(config_.boosting_type, nullptr));
+    boosting_.reset(Boosting::CreateBoosting(config_.boosting_type, "text", nullptr));
 
     train_data_ = train_data;
     CreateObjectiveAndMetrics();
@@ -838,7 +834,7 @@ int LGBM_BoosterLoadModelFromString(
   int* out_num_iterations,
   BoosterHandle* out) {
   API_BEGIN();
-  auto ret = std::unique_ptr<Booster>(new Booster());
+  auto ret = std::unique_ptr<Booster>(new Booster(nullptr));
   ret->LoadModelFromString(model_str);
   *out_num_iterations = ret->GetBoosting()->GetCurrentIteration();
   *out = ret.release();
diff --git a/src/io/config.cpp b/src/io/config.cpp
index 35468e733b52..07b5276369b0 100644
--- a/src/io/config.cpp
+++ b/src/io/config.cpp
@@ -269,6 +269,7 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
   GetString(params, "input_model", &input_model);
   GetString(params, "convert_model", &convert_model);
   GetString(params, "output_result", &output_result);
+  GetString(params, "model_format", &model_format);
   std::string tmp_str = "";
   if (GetString(params, "valid_data", &tmp_str)) {
     valid_data_filenames = Common::Split(tmp_str.c_str(), ',');
diff --git a/src/proto/gbdt_model_proto.cpp b/src/proto/gbdt_model_proto.cpp
new file mode 100644
index 000000000000..fdd7dd0d3bfd
--- /dev/null
+++ b/src/proto/gbdt_model_proto.cpp
@@ -0,0 +1,191 @@
+#include "../boosting/gbdt.h"
+
+#include
+#include
+#include
+#include
+#include
+
+namespace LightGBM {
+
+void GBDT::SaveModelToProto(int num_iteration, const char* filename) const {
+  LightGBM::Model model;
+
+  model.set_name(SubModelName());
+  model.set_num_class(num_class_);
+  model.set_num_tree_per_iteration(num_tree_per_iteration_);
+  model.set_label_index(label_idx_);
+  model.set_max_feature_idx(max_feature_idx_);
+  if (objective_function_ != nullptr) {
+    model.set_objective(objective_function_->ToString());
+  }
+  model.set_average_output(average_output_);
+  for(auto feature_name: feature_names_) {
+    model.add_feature_names(feature_name);
+  }
+  for(auto feature_info: feature_infos_) {
+    model.add_feature_infos(feature_info);
+  }
+
+  int num_used_model = static_cast<int>(models_.size());
+  if (num_iteration > 0) {
+    num_used_model = std::min(num_iteration * num_tree_per_iteration_, num_used_model);
+  }
+  for (int i = 0; i < num_used_model; ++i) {
+    models_[i]->ToProto(*model.add_trees());
+  }
+
+  std::filebuf fb;
+  fb.open(filename, std::ios::out | std::ios::binary);
+  std::ostream os(&fb);
+  if (!model.SerializeToOstream(&os)) {
+    Log::Fatal("Cannot serialize model to binary file.");
+  }
+  fb.close();
+}
+
+bool GBDT::LoadModelFromProto(const char* filename) {
+  models_.clear();
+  LightGBM::Model model;
+  std::filebuf fb;
+  if (fb.open(filename, std::ios::in | std::ios::binary))
+  {
+    std::istream is(&fb);
+    if (!model.ParseFromIstream(&is)) {
+      Log::Fatal("Cannot parse model from binary file.");
+    }
+    fb.close();
+  } else {
+    Log::Fatal("Cannot open file: %s.", filename);
+  }
+
+  num_class_ = model.num_class();
+  num_tree_per_iteration_ = model.num_tree_per_iteration();
+  label_idx_ = model.label_index();
+  max_feature_idx_ = model.max_feature_idx();
+  average_output_ = model.average_output();
+  feature_names_.reserve(model.feature_names_size());
+  for (auto feature_name: model.feature_names()) {
+    feature_names_.push_back(feature_name);
+  }
+  feature_infos_.reserve(model.feature_infos_size());
+  for (auto feature_info: model.feature_infos()) {
+    feature_infos_.push_back(feature_info);
+  }
+  loaded_objective_.reset(ObjectiveFunction::CreateObjectiveFunction(model.objective()));
+  objective_function_ = loaded_objective_.get();
+
+  for (auto tree: model.trees()) {
+    models_.emplace_back(new Tree(tree));
+  }
+  Log::Info("Finished loading %d models", models_.size());
+  num_iteration_for_pred_ = static_cast<int>(models_.size()) / num_tree_per_iteration_;
+  num_init_iteration_ = num_iteration_for_pred_;
+  iter_ = 0;
+
+  return true;
+}
+
+void Tree::ToProto(LightGBM::Model_Tree& model_tree) const {
+
+  model_tree.set_num_leaves(num_leaves_);
+  model_tree.set_num_cat(num_cat_);
+  for (int i = 0; i < num_leaves_ - 1; ++i) {
+    model_tree.add_split_feature(split_feature_[i]);
+    model_tree.add_split_gain(split_gain_[i]);
+    model_tree.add_threshold(threshold_[i]);
+    model_tree.add_decision_type(decision_type_[i]);
+    model_tree.add_left_child(left_child_[i]);
+    model_tree.add_right_child(right_child_[i]);
+    model_tree.add_internal_value(internal_value_[i]);
+    model_tree.add_internal_count(internal_count_[i]);
+  }
+
+  for (int i = 0; i < num_leaves_; ++i) {
+    model_tree.add_leaf_value(leaf_value_[i]);
+    model_tree.add_leaf_count(leaf_count_[i]);
+  }
+
+  if (num_cat_ > 0) {
+    for (int i = 0; i < num_cat_ + 1; ++i) {
+      model_tree.add_cat_boundaries(cat_boundaries_[i]);
+    }
+    for (size_t i = 0; i < cat_threshold_.size(); ++i) {
+      model_tree.add_cat_threshold(cat_threshold_[i]);
+    }
+  }
+  model_tree.set_shrinkage(shrinkage_);
+}
+
+Tree::Tree(const LightGBM::Model_Tree& model_tree) {
+
+  num_leaves_ = model_tree.num_leaves();
+  if (num_leaves_ <= 1) { return; }
+  num_cat_ = model_tree.num_cat();
+
+  leaf_value_.reserve(model_tree.leaf_value_size());
+  for(auto leaf_value: model_tree.leaf_value()) {
+    leaf_value_.push_back(leaf_value);
+  }
+
+  left_child_.reserve(model_tree.left_child_size());
+  for(auto left_child: model_tree.left_child()) {
+    left_child_.push_back(left_child);
+  }
+
+  right_child_.reserve(model_tree.right_child_size());
+  for(auto right_child: model_tree.right_child()) {
+    right_child_.push_back(right_child);
+  }
+
+  split_feature_.reserve(model_tree.split_feature_size());
+  for(auto split_feature: model_tree.split_feature()) {
+    split_feature_.push_back(split_feature);
+  }
+
+  threshold_.reserve(model_tree.threshold_size());
+  for(auto threshold: model_tree.threshold()) {
+    threshold_.push_back(threshold);
+  }
+
+  split_gain_.reserve(model_tree.split_gain_size());
+  for(auto split_gain: model_tree.split_gain()) {
+    split_gain_.push_back(split_gain);
+  }
+
+  internal_count_.reserve(model_tree.internal_count_size());
+  for(auto internal_count: model_tree.internal_count()) {
+    internal_count_.push_back(internal_count);
+  }
+
+  internal_value_.reserve(model_tree.internal_value_size());
+  for(auto internal_value: model_tree.internal_value()) {
+    internal_value_.push_back(internal_value);
+  }
+
+  leaf_count_.reserve(model_tree.leaf_count_size());
+  for(auto leaf_count: model_tree.leaf_count()) {
+    leaf_count_.push_back(leaf_count);
+  }
+
+  decision_type_.reserve(model_tree.decision_type_size());
+  for(auto decision_type: model_tree.decision_type()) {
+    decision_type_.push_back(decision_type);
+  }
+
+  if (num_cat_ > 0) {
+    cat_boundaries_.reserve(model_tree.cat_boundaries_size());
+    for(auto cat_boundaries: model_tree.cat_boundaries()) {
+      cat_boundaries_.push_back(cat_boundaries);
+    }
+
+    cat_threshold_.reserve(model_tree.cat_threshold_size());
+    for(auto cat_threshold: model_tree.cat_threshold()) {
+      cat_threshold_.push_back(cat_threshold);
+    }
+  }
+
+  shrinkage_ = model_tree.shrinkage();
+}
+
+} // namespace LightGBM
diff --git a/tests/cpp_test/train.conf b/tests/cpp_test/train.conf
index 2fb3da3b774f..0b4283e63ce1 100644
--- a/tests/cpp_test/train.conf
+++ b/tests/cpp_test/train.conf
@@ -3,7 +3,3 @@
 data=../data/categorical.data
 app=binary
 num_trees=10
-
-convert_model=../../src/boosting/gbdt_prediction.cpp
-
-convert_model_language=cpp
diff --git a/windows/LightGBM.vcxproj b/windows/LightGBM.vcxproj
index 532f758df37a..5731a92a0ffb 100644
--- a/windows/LightGBM.vcxproj
+++ b/windows/LightGBM.vcxproj
@@ -247,7 +247,7 @@
-    <ClCompile Include="..\src\boosting\gbdt_model.cpp" />
+    <ClCompile Include="..\src\boosting\gbdt_model_text.cpp" />
diff --git a/windows/LightGBM.vcxproj.filters b/windows/LightGBM.vcxproj.filters
index 90b874bbb969..558dc03c3ce6 100644
--- a/windows/LightGBM.vcxproj.filters
+++ b/windows/LightGBM.vcxproj.filters
@@ -278,7 +278,7 @@
       <Filter>src</Filter>
-    <ClCompile Include="..\src\boosting\gbdt_model.cpp">
+    <ClCompile Include="..\src\boosting\gbdt_model_text.cpp">
       <Filter>src\boosting</Filter>
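For context, a minimal sketch (not part of the patch itself) of how the interfaces introduced above could be driven from C++. It assumes a LightGBM build configured with ``cmake -DUSE_PROTO=ON`` and a model previously saved with ``model_format=proto``; the filenames are placeholders.

.. code:: cpp

    #include <LightGBM/boosting.h>

    #include <memory>

    int main() {
      using LightGBM::Boosting;
      // Load a model that was saved with model_format=proto
      // (hypothetical filename).
      std::unique_ptr<Boosting> booster(
          Boosting::CreateBoosting("gbdt", "proto", "LightGBM_model.proto"));
      // Write all iterations (-1) back out in the default text format.
      booster->SaveModelToFile(-1, "LightGBM_model.txt");
      return 0;
    }

The command-line equivalent, as exercised by the new ``proto`` CI task above, is ``lightgbm config=train.conf model_format=proto`` followed by ``lightgbm config=predict.conf model_format=proto``.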