Skip to content

Commit

Permalink
add customized parser support
Browse files Browse the repository at this point in the history
  • Loading branch information
chjinche committed Nov 9, 2021
1 parent b1facf5 commit af70998
Show file tree
Hide file tree
Showing 14 changed files with 251 additions and 10 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ MLflow (experiment tracking, model monitoring framework): https://github.com/mlf

`{mlr3extralearners}` (R `{mlr3}`-compliant interface): https://github.com/mlr-org/mlr3extralearners

lightgbm-transform (transformation binding): https://github.com/microsoft/lightgbm-transform

Support
-------

Expand Down
8 changes: 8 additions & 0 deletions docs/Parameters.rst
Original file line number Diff line number Diff line change
Expand Up @@ -850,6 +850,14 @@ Dataset Parameters

- **Note**: setting this to ``true`` may lead to much slower text parsing

- ``parser_config_file`` :raw-html:`<a id="parser_config_file" title="Permalink to this parameter" href="#parser_config_file">&#x1F517;&#xFE0E;</a>`, default = ``""``, type = string

- path to a ``.json`` file that specifies customized parser initialized configuration

- see `lightgbm-transform <https://github.com/microsoft/lightgbm-transform>`__ for usage examples

- **Note**: ``lightgbm-transform`` is not maintained by LightGBM's maintainers. Bug reports or feature requests should go to `issue <https://github.com/microsoft/lightgbm-transform/issues>`__

Predict Parameters
~~~~~~~~~~~~~~~~~~

Expand Down
2 changes: 2 additions & 0 deletions include/LightGBM/boosting.h
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,8 @@ class LIGHTGBM_EXPORT Boosting {
static Boosting* CreateBoosting(const std::string& type, const char* filename);

virtual bool IsLinear() const { return false; }

virtual std::string ParserConfigStr() const = 0;
};

class GBDTBase : public Boosting {
Expand Down
5 changes: 5 additions & 0 deletions include/LightGBM/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -721,6 +721,11 @@ struct Config {
// desc = **Note**: setting this to ``true`` may lead to much slower text parsing
bool precise_float_parser = false;

// desc = path to a ``.json`` file that specifies customized parser initialized configuration
// desc = see `lightgbm-transform <https://github.com/microsoft/lightgbm-transform>`__ for usage examples
// desc = **Note**: ``lightgbm-transform`` is not maintained by LightGBM's maintainers. Bug reports or feature requests should go to `issue <https://github.com/microsoft/lightgbm-transform/issues>`__
std::string parser_config_file = "";

#pragma endregion

#pragma region Predict Parameters
Expand Down
59 changes: 59 additions & 0 deletions include/LightGBM/dataset.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

#include <string>
#include <functional>
#include <map>
#include <memory>
#include <mutex>
#include <unordered_set>
Expand Down Expand Up @@ -254,6 +255,14 @@ class Parser {
public:
typedef const char* (*AtofFunc)(const char* p, double* out);

/*! \brief Default constructor */
Parser() {}

/*!
* \brief Constructor for customized parser. The constructor accepts content not path because need to save/load the config along with model string
*/
explicit Parser(std::string) {}

/*! \brief virtual destructor */
virtual ~Parser() {}

Expand All @@ -271,12 +280,58 @@ class Parser {
/*!
* \brief Create an object of parser, will auto choose the format depend on file
* \param filename One Filename of data
* \param header whether input file contains header
* \param num_features Pass num_features of this data file if you know, <=0 means don't know
* \param label_idx index of label column
* \param precise_float_parser using precise floating point number parsing if true
* \return Object of parser
*/
static Parser* CreateParser(const char* filename, bool header, int num_features, int label_idx, bool precise_float_parser);

/*!
* \brief Create an object of parser, could use customized parser, or auto choose the format depend on file
* \param filename One Filename of data
* \param header whether input file contains header
* \param num_features Pass num_features of this data file if you know, <=0 means don't know
* \param label_idx index of label column
* \param precise_float_parser using precise floating point number parsing if true
* \param parser_config_str Customized parser config content
* \return Object of parser
*/
static Parser* CreateParser(const char* filename, bool header, int num_features, int label_idx, bool precise_float_parser,
std::string parser_config_str);

/*!
* \brief Generate parser config str used for custom parser initialization, may save values of label id and header
* \param filename One Filename of data
* \param parser_config_filename One Filename of parser config
* \param header whether input file contains header
* \param label_idx index of label column
* \return Parser config str
*/
static std::string GenerateParserConfigStr(const char* filename, const char* parser_config_filename, bool header, int label_idx);
};

/*! \brief Interface for parser factory, used by customized parser */
class ParserFactory {
private:
ParserFactory() {}
std::map<std::string, std::function<Parser*(std::string)>> object_map_;

public:
~ParserFactory() {}
static ParserFactory& getInstance();
void Register(std::string class_name, std::function<Parser*(std::string)> objc);
Parser* getObject(std::string class_name, std::string config_str);
};

/*! \brief Interface for parser reflector, used by customized parser */
class ParserReflector {
public:
ParserReflector(std::string class_name, std::function<Parser*(std::string)> objc) {
ParserFactory::getInstance().Register(class_name, objc);
}
virtual ~ParserReflector() {}
};

/*! \brief The main class of data set,
Expand Down Expand Up @@ -605,6 +660,9 @@ class Dataset {
/*! \brief Get names of current data set */
inline const std::vector<std::string>& feature_names() const { return feature_names_; }

/*! \brief Get content of parser config file */
inline const std::string parser_config_str() const { return parser_config_str_; }

inline void set_feature_names(const std::vector<std::string>& feature_names) {
if (feature_names.size() != static_cast<size_t>(num_total_features_)) {
Log::Fatal("Size of feature_names error, should equal with total number of features");
Expand Down Expand Up @@ -722,6 +780,7 @@ class Dataset {
/*! map feature (inner index) to its index in the list of numeric (non-categorical) features */
std::vector<int> numeric_feature_map_;
int num_numeric_features_;
std::string parser_config_str_;
};

} // namespace LightGBM
Expand Down
41 changes: 41 additions & 0 deletions include/LightGBM/utils/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#if ((defined(sun) || defined(__sun)) && (defined(__SVR4) || defined(__svr4__)))
#include <LightGBM/utils/common_legacy_solaris.h>
#endif
#include <LightGBM/utils/json11.h>
#include <LightGBM/utils/log.h>
#include <LightGBM/utils/openmp_wrapper.h>

Expand All @@ -30,6 +31,7 @@
#include <unordered_map>
#include <utility>
#include <vector>
#include <fstream>

#if (!((defined(sun) || defined(__sun)) && (defined(__SVR4) || defined(__svr4__))))
#define FMT_HEADER_ONLY
Expand Down Expand Up @@ -62,6 +64,8 @@ namespace LightGBM {

namespace Common {

using json11::Json;

/*!
* Imbues the stream with the C locale.
*/
Expand Down Expand Up @@ -200,6 +204,43 @@ inline static std::vector<std::string> Split(const char* c_str, const char* deli
return ret;
}

inline static std::string LoadStringFromFile(const char* filename, int row_num = INT_MAX) {
if (filename == NULL || *filename == '\0') {
return "";
}
std::stringstream ss;
Common::C_stringstream(ss);
std::ifstream fin(filename);
std::string line = "";
int i = 0;
while (std::getline(fin, line) && i++ < row_num) {
ss << line << "\n";
}
return ss.str();
}

inline static std::string GetFromParserConfig(std::string config_str, std::string key) {
// parser config should follow json format.
std::string err;
Json config_json = Json::parse(config_str, &err);
if (!err.empty()) {
Log::Fatal("Invalid parser config: %s. Please check if follow json format.", err.c_str());
}
return config_json[key].string_value();
}

inline static std::string SaveToParserConfig(std::string config_str, std::string key, std::string value) {
std::string err;
Json config_json = Json::parse(config_str, &err);
if (!err.empty()) {
Log::Fatal("Invalid parser config: %s. Please check if follow json format.", err.c_str());
}
CHECK(config_json.is_object());
std::map<std::string, Json> config_map = config_json.object_items();
config_map.insert(std::pair<std::string, Json>(key, Json(value)));
return Json(config_map).dump();
}

template<typename T>
inline static const char* Atoi(const char* p, T* out) {
int sign;
Expand Down
6 changes: 4 additions & 2 deletions src/application/predictor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <LightGBM/boosting.h>
#include <LightGBM/dataset.h>
#include <LightGBM/meta.h>
#include <LightGBM/utils/common.h>
#include <LightGBM/utils/openmp_wrapper.h>
#include <LightGBM/utils/text_reader.h>

Expand Down Expand Up @@ -167,7 +168,7 @@ class Predictor {
}
auto label_idx = header ? -1 : boosting_->LabelIdx();
auto parser = std::unique_ptr<Parser>(Parser::CreateParser(data_filename, header, boosting_->MaxFeatureIdx() + 1, label_idx,
precise_float_parser));
precise_float_parser, boosting_->ParserConfigStr()));

if (parser == nullptr) {
Log::Fatal("Could not recognize the data format of data file %s", data_filename);
Expand All @@ -179,7 +180,8 @@ class Predictor {
TextReader<data_size_t> predict_data_reader(data_filename, header);
std::vector<int> feature_remapper(parser->NumFeatures(), -1);
bool need_adjust = false;
if (header) {
// skip raw feature remapping if trained model has parser config str which may contain actual feature names.
if (header && boosting_->ParserConfigStr().empty()) {
std::string first_line = predict_data_reader.first_line();
std::vector<std::string> header_words = Common::Split(first_line.c_str(), "\t,");
std::unordered_map<std::string, int> header_mapper;
Expand Down
2 changes: 2 additions & 0 deletions src/boosting/gbdt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ void GBDT::Init(const Config* config, const Dataset* train_data, const Objective
feature_names_ = train_data_->feature_names();
feature_infos_ = train_data_->feature_infos();
monotone_constraints_ = config->monotone_constraints;
// get parser config file content
parser_config_str_ = train_data_->parser_config_str();

// if need bagging, create buffer
ResetBaggingConfig(config_.get(), true);
Expand Down
4 changes: 4 additions & 0 deletions src/boosting/gbdt.h
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,8 @@ class GBDT : public GBDTBase {

bool IsLinear() const override { return linear_tree_; }

inline std::string ParserConfigStr() const override {return parser_config_str_;}

protected:
virtual bool GetIsConstHessian(const ObjectiveFunction* objective_function) {
if (objective_function != nullptr) {
Expand Down Expand Up @@ -483,6 +485,8 @@ class GBDT : public GBDTBase {
std::vector<std::unique_ptr<Tree>> models_;
/*! \brief Max feature index of training data*/
int max_feature_idx_;
/*! \brief Parser config file content */
std::string parser_config_str_ = "";

#ifdef USE_CUDA
/*! \brief First order derivative of training data */
Expand Down
29 changes: 28 additions & 1 deletion src/boosting/gbdt_model_text.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,11 @@ std::string GBDT::SaveModelToString(int start_iteration, int num_iteration, int
ss << loaded_parameter_ << "\n";
ss << "end of parameters" << '\n';
}
if (!parser_config_str_.empty()) {
ss << "\nparser:" << '\n';
ss << parser_config_str_ << "\n";
ss << "end of parser" << '\n';
}
return ss.str();
}

Expand Down Expand Up @@ -568,7 +573,7 @@ bool GBDT::LoadModelFromString(const char* buffer, size_t len) {
num_iteration_for_pred_ = static_cast<int>(models_.size()) / num_tree_per_iteration_;
num_init_iteration_ = num_iteration_for_pred_;
iter_ = 0;
bool is_inparameter = false;
bool is_inparameter = false, is_inparser = false;
std::stringstream ss;
Common::C_stringstream(ss);
while (p < end) {
Expand All @@ -594,6 +599,28 @@ bool GBDT::LoadModelFromString(const char* buffer, size_t len) {
if (!ss.str().empty()) {
loaded_parameter_ = ss.str();
}
ss.clear();
ss.str("");
while (p < end) {
auto line_len = Common::GetLine(p);
if (line_len > 0) {
std::string cur_line(p, line_len);
if (cur_line == std::string("parser:")) {
is_inparser = true;
} else if (cur_line == std::string("end of parser")) {
p += line_len;
p = Common::SkipNewLine(p);
break;
} else if (is_inparser) {
ss << cur_line << "\n";
}
}
p += line_len;
p = Common::SkipNewLine(p);
}
parser_config_str_ = ss.str();
ss.clear();
ss.str("");
return true;
}

Expand Down
4 changes: 4 additions & 0 deletions src/io/config_auto.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ const std::unordered_set<std::string>& Config::parameter_set() {
"forcedbins_filename",
"save_binary",
"precise_float_parser",
"parser_config_file",
"start_iteration_predict",
"num_iteration_predict",
"predict_raw_score",
Expand Down Expand Up @@ -539,6 +540,8 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str

GetBool(params, "precise_float_parser", &precise_float_parser);

GetString(params, "parser_config_file", &parser_config_file);

GetInt(params, "start_iteration_predict", &start_iteration_predict);

GetInt(params, "num_iteration_predict", &num_iteration_predict);
Expand Down Expand Up @@ -722,6 +725,7 @@ std::string Config::SaveMembersToString() const {
str_buf << "[categorical_feature: " << categorical_feature << "]\n";
str_buf << "[forcedbins_filename: " << forcedbins_filename << "]\n";
str_buf << "[precise_float_parser: " << precise_float_parser << "]\n";
str_buf << "[parser_config_file: " << parser_config_file << "]\n";
str_buf << "[objective_seed: " << objective_seed << "]\n";
str_buf << "[num_class: " << num_class << "]\n";
str_buf << "[is_unbalance: " << is_unbalance << "]\n";
Expand Down
Loading

0 comments on commit af70998

Please sign in to comment.