From af70998e1c4a6d5d06361471ca2de65431504a70 Mon Sep 17 00:00:00 2001 From: Jincheng Chen Date: Thu, 4 Nov 2021 11:10:09 +0000 Subject: [PATCH] add customized parser support --- README.md | 2 + docs/Parameters.rst | 8 +++ include/LightGBM/boosting.h | 2 + include/LightGBM/config.h | 5 ++ include/LightGBM/dataset.h | 59 +++++++++++++++++++++ include/LightGBM/utils/common.h | 41 ++++++++++++++ src/application/predictor.hpp | 6 ++- src/boosting/gbdt.cpp | 2 + src/boosting/gbdt.h | 4 ++ src/boosting/gbdt_model_text.cpp | 29 +++++++++- src/io/config_auto.cpp | 4 ++ src/io/dataset_loader.cpp | 32 ++++++++--- src/io/parser.cpp | 53 ++++++++++++++++++ tests/python_package_test/test_utilities.py | 14 +++++ 14 files changed, 251 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index 921bfb76f308..f689856265cd 100644 --- a/README.md +++ b/README.md @@ -117,6 +117,8 @@ MLflow (experiment tracking, model monitoring framework): https://github.com/mlf `{mlr3extralearners}` (R `{mlr3}`-compliant interface): https://github.com/mlr-org/mlr3extralearners +lightgbm-transform (transformation binding): https://github.com/microsoft/lightgbm-transform + Support ------- diff --git a/docs/Parameters.rst b/docs/Parameters.rst index 088773d0690e..ab5ab165e550 100644 --- a/docs/Parameters.rst +++ b/docs/Parameters.rst @@ -850,6 +850,14 @@ Dataset Parameters - **Note**: setting this to ``true`` may lead to much slower text parsing +- ``parser_config_file`` :raw-html:`🔗︎`, default = ``""``, type = string + + - path to a ``.json`` file that specifies customized parser initialized configuration + + - see `lightgbm-transform `__ for usage examples + + - **Note**: ``lightgbm-transform`` is not maintained by LightGBM's maintainers. Bug reports or feature requests should go to `issue `__ + Predict Parameters ~~~~~~~~~~~~~~~~~~ diff --git a/include/LightGBM/boosting.h b/include/LightGBM/boosting.h index ddbcdbc18e44..7530495c0e17 100644 --- a/include/LightGBM/boosting.h +++ b/include/LightGBM/boosting.h @@ -314,6 +314,8 @@ class LIGHTGBM_EXPORT Boosting { static Boosting* CreateBoosting(const std::string& type, const char* filename); virtual bool IsLinear() const { return false; } + + virtual std::string ParserConfigStr() const = 0; }; class GBDTBase : public Boosting { diff --git a/include/LightGBM/config.h b/include/LightGBM/config.h index c94420645cd9..6686692c740f 100644 --- a/include/LightGBM/config.h +++ b/include/LightGBM/config.h @@ -721,6 +721,11 @@ struct Config { // desc = **Note**: setting this to ``true`` may lead to much slower text parsing bool precise_float_parser = false; + // desc = path to a ``.json`` file that specifies customized parser initialized configuration + // desc = see `lightgbm-transform `__ for usage examples + // desc = **Note**: ``lightgbm-transform`` is not maintained by LightGBM's maintainers. Bug reports or feature requests should go to `issue `__ + std::string parser_config_file = ""; + #pragma endregion #pragma region Predict Parameters diff --git a/include/LightGBM/dataset.h b/include/LightGBM/dataset.h index abef980f5fbe..cf19429322ee 100644 --- a/include/LightGBM/dataset.h +++ b/include/LightGBM/dataset.h @@ -15,6 +15,7 @@ #include #include +#include #include #include #include @@ -254,6 +255,14 @@ class Parser { public: typedef const char* (*AtofFunc)(const char* p, double* out); + /*! \brief Default constructor */ + Parser() {} + + /*! + * \brief Constructor for customized parser. The constructor accepts content not path because need to save/load the config along with model string + */ + explicit Parser(std::string) {} + /*! \brief virtual destructor */ virtual ~Parser() {} @@ -271,12 +280,58 @@ class Parser { /*! * \brief Create an object of parser, will auto choose the format depend on file * \param filename One Filename of data + * \param header whether input file contains header * \param num_features Pass num_features of this data file if you know, <=0 means don't know * \param label_idx index of label column * \param precise_float_parser using precise floating point number parsing if true * \return Object of parser */ static Parser* CreateParser(const char* filename, bool header, int num_features, int label_idx, bool precise_float_parser); + + /*! + * \brief Create an object of parser, could use customized parser, or auto choose the format depend on file + * \param filename One Filename of data + * \param header whether input file contains header + * \param num_features Pass num_features of this data file if you know, <=0 means don't know + * \param label_idx index of label column + * \param precise_float_parser using precise floating point number parsing if true + * \param parser_config_str Customized parser config content + * \return Object of parser + */ + static Parser* CreateParser(const char* filename, bool header, int num_features, int label_idx, bool precise_float_parser, + std::string parser_config_str); + + /*! + * \brief Generate parser config str used for custom parser initialization, may save values of label id and header + * \param filename One Filename of data + * \param parser_config_filename One Filename of parser config + * \param header whether input file contains header + * \param label_idx index of label column + * \return Parser config str + */ + static std::string GenerateParserConfigStr(const char* filename, const char* parser_config_filename, bool header, int label_idx); +}; + +/*! \brief Interface for parser factory, used by customized parser */ +class ParserFactory { + private: + ParserFactory() {} + std::map> object_map_; + + public: + ~ParserFactory() {} + static ParserFactory& getInstance(); + void Register(std::string class_name, std::function objc); + Parser* getObject(std::string class_name, std::string config_str); +}; + +/*! \brief Interface for parser reflector, used by customized parser */ +class ParserReflector { + public: + ParserReflector(std::string class_name, std::function objc) { + ParserFactory::getInstance().Register(class_name, objc); + } + virtual ~ParserReflector() {} }; /*! \brief The main class of data set, @@ -605,6 +660,9 @@ class Dataset { /*! \brief Get names of current data set */ inline const std::vector& feature_names() const { return feature_names_; } + /*! \brief Get content of parser config file */ + inline const std::string parser_config_str() const { return parser_config_str_; } + inline void set_feature_names(const std::vector& feature_names) { if (feature_names.size() != static_cast(num_total_features_)) { Log::Fatal("Size of feature_names error, should equal with total number of features"); @@ -722,6 +780,7 @@ class Dataset { /*! map feature (inner index) to its index in the list of numeric (non-categorical) features */ std::vector numeric_feature_map_; int num_numeric_features_; + std::string parser_config_str_; }; } // namespace LightGBM diff --git a/include/LightGBM/utils/common.h b/include/LightGBM/utils/common.h index af9810627b04..a44d9ad7f9a0 100644 --- a/include/LightGBM/utils/common.h +++ b/include/LightGBM/utils/common.h @@ -8,6 +8,7 @@ #if ((defined(sun) || defined(__sun)) && (defined(__SVR4) || defined(__svr4__))) #include #endif +#include #include #include @@ -30,6 +31,7 @@ #include #include #include +#include #if (!((defined(sun) || defined(__sun)) && (defined(__SVR4) || defined(__svr4__)))) #define FMT_HEADER_ONLY @@ -62,6 +64,8 @@ namespace LightGBM { namespace Common { +using json11::Json; + /*! * Imbues the stream with the C locale. */ @@ -200,6 +204,43 @@ inline static std::vector Split(const char* c_str, const char* deli return ret; } +inline static std::string LoadStringFromFile(const char* filename, int row_num = INT_MAX) { + if (filename == NULL || *filename == '\0') { + return ""; + } + std::stringstream ss; + Common::C_stringstream(ss); + std::ifstream fin(filename); + std::string line = ""; + int i = 0; + while (std::getline(fin, line) && i++ < row_num) { + ss << line << "\n"; + } + return ss.str(); +} + +inline static std::string GetFromParserConfig(std::string config_str, std::string key) { + // parser config should follow json format. + std::string err; + Json config_json = Json::parse(config_str, &err); + if (!err.empty()) { + Log::Fatal("Invalid parser config: %s. Please check if follow json format.", err.c_str()); + } + return config_json[key].string_value(); +} + +inline static std::string SaveToParserConfig(std::string config_str, std::string key, std::string value) { + std::string err; + Json config_json = Json::parse(config_str, &err); + if (!err.empty()) { + Log::Fatal("Invalid parser config: %s. Please check if follow json format.", err.c_str()); + } + CHECK(config_json.is_object()); + std::map config_map = config_json.object_items(); + config_map.insert(std::pair(key, Json(value))); + return Json(config_map).dump(); +} + template inline static const char* Atoi(const char* p, T* out) { int sign; diff --git a/src/application/predictor.hpp b/src/application/predictor.hpp index dff23add2df5..d1a8aca4d041 100644 --- a/src/application/predictor.hpp +++ b/src/application/predictor.hpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include @@ -167,7 +168,7 @@ class Predictor { } auto label_idx = header ? -1 : boosting_->LabelIdx(); auto parser = std::unique_ptr(Parser::CreateParser(data_filename, header, boosting_->MaxFeatureIdx() + 1, label_idx, - precise_float_parser)); + precise_float_parser, boosting_->ParserConfigStr())); if (parser == nullptr) { Log::Fatal("Could not recognize the data format of data file %s", data_filename); @@ -179,7 +180,8 @@ class Predictor { TextReader predict_data_reader(data_filename, header); std::vector feature_remapper(parser->NumFeatures(), -1); bool need_adjust = false; - if (header) { + // skip raw feature remapping if trained model has parser config str which may contain actual feature names. + if (header && boosting_->ParserConfigStr().empty()) { std::string first_line = predict_data_reader.first_line(); std::vector header_words = Common::Split(first_line.c_str(), "\t,"); std::unordered_map header_mapper; diff --git a/src/boosting/gbdt.cpp b/src/boosting/gbdt.cpp index d393d46d5133..e24cb52f10d9 100644 --- a/src/boosting/gbdt.cpp +++ b/src/boosting/gbdt.cpp @@ -120,6 +120,8 @@ void GBDT::Init(const Config* config, const Dataset* train_data, const Objective feature_names_ = train_data_->feature_names(); feature_infos_ = train_data_->feature_infos(); monotone_constraints_ = config->monotone_constraints; + // get parser config file content + parser_config_str_ = train_data_->parser_config_str(); // if need bagging, create buffer ResetBaggingConfig(config_.get(), true); diff --git a/src/boosting/gbdt.h b/src/boosting/gbdt.h index 472ea1707104..efeacfbfaef0 100644 --- a/src/boosting/gbdt.h +++ b/src/boosting/gbdt.h @@ -394,6 +394,8 @@ class GBDT : public GBDTBase { bool IsLinear() const override { return linear_tree_; } + inline std::string ParserConfigStr() const override {return parser_config_str_;} + protected: virtual bool GetIsConstHessian(const ObjectiveFunction* objective_function) { if (objective_function != nullptr) { @@ -483,6 +485,8 @@ class GBDT : public GBDTBase { std::vector> models_; /*! \brief Max feature index of training data*/ int max_feature_idx_; + /*! \brief Parser config file content */ + std::string parser_config_str_ = ""; #ifdef USE_CUDA /*! \brief First order derivative of training data */ diff --git a/src/boosting/gbdt_model_text.cpp b/src/boosting/gbdt_model_text.cpp index 0e296c1ec984..f8467e8fb98e 100644 --- a/src/boosting/gbdt_model_text.cpp +++ b/src/boosting/gbdt_model_text.cpp @@ -399,6 +399,11 @@ std::string GBDT::SaveModelToString(int start_iteration, int num_iteration, int ss << loaded_parameter_ << "\n"; ss << "end of parameters" << '\n'; } + if (!parser_config_str_.empty()) { + ss << "\nparser:" << '\n'; + ss << parser_config_str_ << "\n"; + ss << "end of parser" << '\n'; + } return ss.str(); } @@ -568,7 +573,7 @@ bool GBDT::LoadModelFromString(const char* buffer, size_t len) { num_iteration_for_pred_ = static_cast(models_.size()) / num_tree_per_iteration_; num_init_iteration_ = num_iteration_for_pred_; iter_ = 0; - bool is_inparameter = false; + bool is_inparameter = false, is_inparser = false; std::stringstream ss; Common::C_stringstream(ss); while (p < end) { @@ -594,6 +599,28 @@ bool GBDT::LoadModelFromString(const char* buffer, size_t len) { if (!ss.str().empty()) { loaded_parameter_ = ss.str(); } + ss.clear(); + ss.str(""); + while (p < end) { + auto line_len = Common::GetLine(p); + if (line_len > 0) { + std::string cur_line(p, line_len); + if (cur_line == std::string("parser:")) { + is_inparser = true; + } else if (cur_line == std::string("end of parser")) { + p += line_len; + p = Common::SkipNewLine(p); + break; + } else if (is_inparser) { + ss << cur_line << "\n"; + } + } + p += line_len; + p = Common::SkipNewLine(p); + } + parser_config_str_ = ss.str(); + ss.clear(); + ss.str(""); return true; } diff --git a/src/io/config_auto.cpp b/src/io/config_auto.cpp index 4e3f000a88f5..05edc7b44ff2 100644 --- a/src/io/config_auto.cpp +++ b/src/io/config_auto.cpp @@ -271,6 +271,7 @@ const std::unordered_set& Config::parameter_set() { "forcedbins_filename", "save_binary", "precise_float_parser", + "parser_config_file", "start_iteration_predict", "num_iteration_predict", "predict_raw_score", @@ -539,6 +540,8 @@ void Config::GetMembersFromString(const std::unordered_mapparser_config_str_ = Parser::GenerateParserConfigStr(filename, config_.parser_config_file.c_str(), config_.header, label_idx_); auto parser = std::unique_ptr(Parser::CreateParser(filename, config_.header, 0, label_idx_, - config_.precise_float_parser)); + config_.precise_float_parser, dataset->parser_config_str_)); if (parser == nullptr) { Log::Fatal("Could not recognize data format of %s", filename); } @@ -257,8 +275,6 @@ Dataset* DatasetLoader::LoadFromFile(const char* filename, int rank, int num_mac return dataset.release(); } - - Dataset* DatasetLoader::LoadFromFileAlignWithOtherDataset(const char* filename, const Dataset* train_data) { data_size_t num_global_data = 0; std::vector used_data_indices; @@ -269,7 +285,7 @@ Dataset* DatasetLoader::LoadFromFileAlignWithOtherDataset(const char* filename, auto bin_filename = CheckCanLoadFromBin(filename); if (bin_filename.size() == 0) { auto parser = std::unique_ptr(Parser::CreateParser(filename, config_.header, 0, label_idx_, - config_.precise_float_parser)); + config_.precise_float_parser, dataset->parser_config_str_)); if (parser == nullptr) { Log::Fatal("Could not recognize data format of %s", filename); } @@ -1010,7 +1026,11 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines, categorical_features_); // check the range of label_idx, weight_idx and group_idx - CHECK(label_idx_ >= 0 && label_idx_ <= dataset->num_total_features_); + // skip label check if user input parser config file, + // because label id is got from raw features while dataset features are consistent with customized parser. + if (dataset->parser_config_str_.empty()) { + CHECK(label_idx_ >= 0 && label_idx_ <= dataset->num_total_features_); + } CHECK(weight_idx_ < 0 || weight_idx_ < dataset->num_total_features_); CHECK(group_idx_ < 0 || group_idx_ < dataset->num_total_features_); @@ -1383,8 +1403,6 @@ std::string DatasetLoader::CheckCanLoadFromBin(const char* filename) { } } - - std::vector> DatasetLoader::GetForcedBins(std::string forced_bins_path, int num_total_features, const std::unordered_set& categorical_features) { std::vector> forced_bins(num_total_features, std::vector()); diff --git a/src/io/parser.cpp b/src/io/parser.cpp index 58f2d5b94467..3910e41ebdd0 100644 --- a/src/io/parser.cpp +++ b/src/io/parser.cpp @@ -4,8 +4,10 @@ */ #include "parser.hpp" +#include #include #include +#include #include namespace LightGBM { @@ -230,6 +232,30 @@ DataType GetDataType(const char* filename, bool header, return type; } +// parser factory implementation. +ParserFactory& ParserFactory::getInstance() { + static ParserFactory factory; + return factory; +} + +void ParserFactory::Register(std::string class_name, std::function m_objc) { + if (m_objc) { + object_map_.insert( + std::map>::value_type(class_name, m_objc)); + } +} + +Parser* ParserFactory::getObject(std::string class_name, std::string config_str) { + std::map>::const_iterator iter = + object_map_.find(class_name); + if (iter != object_map_.end()) { + return iter->second(config_str); + } else { + Log::Fatal("Cannot find parser class '%s', please register first or check config format.", class_name.c_str()); + return nullptr; + } +} + Parser* Parser::CreateParser(const char* filename, bool header, int num_features, int label_idx, bool precise_float_parser) { const int n_read_line = 32; auto lines = ReadKLineFromFile(filename, header, n_read_line); @@ -258,4 +284,31 @@ Parser* Parser::CreateParser(const char* filename, bool header, int num_features return ret.release(); } +Parser* Parser::CreateParser(const char* filename, bool header, int num_features, int label_idx, bool precise_float_parser, std::string parser_config_str) { + // customized parser add-on. + if (!parser_config_str.empty()) { + std::unique_ptr ret; + std::string class_name = Common::GetFromParserConfig(parser_config_str, "className"); + Log::Info("Custom parser class name: %s", class_name.c_str()); + Parser* p = ParserFactory::getInstance().getObject(class_name, parser_config_str); + ret.reset(p); + return ret.release(); + } + return CreateParser(filename, header, num_features, label_idx, precise_float_parser); +} + +std::string Parser::GenerateParserConfigStr(const char* filename, const char* parser_config_filename, bool header, int label_idx) { + std::string parser_config_str = Common::LoadStringFromFile(parser_config_filename); + if (!parser_config_str.empty()) { + // save header to parser config in case needed. + if (header && Common::GetFromParserConfig(parser_config_str, "header").empty()) { + parser_config_str = Common::SaveToParserConfig(parser_config_str, "header", Common::LoadStringFromFile(filename, 1)); + } + // save label id to parser config in case needed. + if (Common::GetFromParserConfig(parser_config_str, "labelId").empty()) { + parser_config_str = Common::SaveToParserConfig(parser_config_str, "labelId", std::to_string(label_idx)); + } + } + return parser_config_str; +} } // namespace LightGBM diff --git a/tests/python_package_test/test_utilities.py b/tests/python_package_test/test_utilities.py index be57d585f695..ed62422e063b 100644 --- a/tests/python_package_test/test_utilities.py +++ b/tests/python_package_test/test_utilities.py @@ -1,7 +1,9 @@ # coding: utf-8 import logging +from pathlib import Path import numpy as np +import pytest import lightgbm as lgb @@ -97,3 +99,15 @@ def dummy_metric(_, __): actual_log_wo_gpu_stuff.append(line) assert "\n".join(actual_log_wo_gpu_stuff) == expected_log + + +def test_smoke_custom_parser(tmp_path): + data_path = Path(__file__).absolute().parents[2] / 'examples/binary_classification/binary.train' + parser_config_file = tmp_path / 'parser.ini' + with open(parser_config_file, 'w') as fout: + fout.write("{\"className\": \"dummy\", \"id\":\"1\"}") + + data = lgb.Dataset(data_path, params={"parser_config_file": parser_config_file}) + with pytest.raises(lgb.basic.LightGBMError, + match="Cannot find parser class 'dummy', please register first or check config format"): + lgb.train({}, data)