From 75b5574260c9706077726cde674466058edef386 Mon Sep 17 00:00:00 2001 From: Philip Hyunsu Cho Date: Sun, 11 Jul 2021 03:20:47 -0700 Subject: [PATCH] Remove the use of dmlc::JSONWriter, dmlc::Stream, and dmlc serializer (#289) --- include/treelite/annotator.h | 16 ++- include/treelite/tree.h | 24 ++-- include/treelite/tree_impl.h | 4 +- src/CMakeLists.txt | 2 +- src/annotator.cc | 66 ++++++---- src/c_api/c_api.cc | 6 +- src/compiler/ast/ast.h | 5 +- src/compiler/ast/builder.h | 3 +- src/compiler/ast/load_data_counts.cc | 18 +-- src/compiler/ast_native.cc | 47 +++++--- src/compiler/failsafe.cc | 39 ++++-- src/frontend/builder.cc | 2 +- src/frontend/lightgbm.cc | 56 ++------- src/frontend/sklearn.cc | 12 +- src/frontend/xgboost.cc | 65 +++++----- src/frontend/xgboost_json.cc | 5 +- src/gtil/predict.cc | 4 +- src/json_serializer.cc | 174 +++++++++++++++++++++++++++ src/reference_serializer.cc | 82 ------------- tests/cpp/test_serializer.cc | 12 +- 20 files changed, 381 insertions(+), 261 deletions(-) create mode 100644 src/json_serializer.cc delete mode 100644 src/reference_serializer.cc diff --git a/include/treelite/annotator.h b/include/treelite/annotator.h index 2e5972c1..fce0ac08 100644 --- a/include/treelite/annotator.h +++ b/include/treelite/annotator.h @@ -1,5 +1,5 @@ /*! - * Copyright (c) 2017-2020 by Contributors + * Copyright (c) 2017-2021 by Contributors * \file annotator.h * \author Hyunsu Cho * \brief Branch annotation tools @@ -9,7 +9,11 @@ #include #include +#include +#include #include +#include +#include namespace treelite { @@ -29,12 +33,12 @@ class BranchAnnotator { * \brief load branch annotation from a JSON file * \param fi input stream */ - void Load(dmlc::Stream* fi); + void Load(std::istream& fi); /*! * \brief save branch annotation to a JSON file * \param fo output stream */ - void Save(dmlc::Stream* fo) const; + void Save(std::ostream& fo) const; /*! * \brief fetch branch annotation. * Usage example: @@ -48,12 +52,12 @@ class BranchAnnotator { * \endcode * \return branch annotation in 2D vector */ - inline std::vector> Get() const { - return counts; + inline std::vector> Get() const { + return counts_; } private: - std::vector> counts; + std::vector> counts_; }; } // namespace treelite diff --git a/include/treelite/tree.h b/include/treelite/tree.h index 99619ffe..e7f39315 100644 --- a/include/treelite/tree.h +++ b/include/treelite/tree.h @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #include @@ -27,14 +28,6 @@ #define TREELITE_MAX_PRED_TRANSFORM_LENGTH 256 -/* Foward declarations */ -namespace dmlc { - -class Stream; -float stof(const std::string& value, std::size_t* pos); - -} // namespace dmlc - namespace treelite { // Represent a frame in the Python buffer protocol (PEP 3118). We use a simplified representation @@ -161,7 +154,7 @@ enum class TaskType : uint8_t { }; /*! \brief Group of parameters that are dependent on the choice of the task type. */ -struct TaskParameter { +struct TaskParam { enum class OutputType : uint8_t { kFloat = 0, kInt = 1 }; /*! \brief The type of output from each leaf node. */ OutputType output_type; @@ -190,7 +183,7 @@ struct TaskParameter { unsigned int leaf_vector_size; }; -static_assert(std::is_pod::value, "TaskParameter must be POD type"); +static_assert(std::is_pod::value, "TaskParameter must be POD type"); /*! 
\brief in-memory representation of a decision tree */ template @@ -289,6 +282,9 @@ class Tree { ContiguousArray matching_categories_; ContiguousArray matching_categories_offset_; + template + friend void SerializeTreeToJSON(WriterType& writer, const Tree& tree); + // allocate a new node inline int AllocNode(); @@ -562,8 +558,6 @@ class Tree { node.gain_ = gain; node.gain_present_ = true; } - - void ReferenceSerialize(dmlc::Stream* fo) const; }; struct ModelParam { @@ -656,7 +650,7 @@ class Model { virtual std::size_t GetNumTree() const = 0; virtual void SetTreeLimit(std::size_t limit) = 0; - virtual void ReferenceSerialize(dmlc::Stream* fo) const = 0; + virtual void SerializeToJSON(std::ostream& fo) const = 0; /* In-memory serialization, zero-copy */ std::vector GetPyBuffer(); @@ -676,7 +670,7 @@ class Model { /*! \brief whether to average tree outputs */ bool average_tree_output; /*! \brief Group of parameters that are specific to the particular task type */ - TaskParameter task_param; + TaskParam task_param; /*! \brief extra parameters */ ModelParam param; @@ -712,7 +706,7 @@ class ModelImpl : public Model { ModelImpl(ModelImpl&&) noexcept = default; ModelImpl& operator=(ModelImpl&&) noexcept = default; - void ReferenceSerialize(dmlc::Stream* fo) const override; + void SerializeToJSON(std::ostream& fo) const override; inline std::size_t GetNumTree() const override { return trees.size(); } diff --git a/include/treelite/tree_impl.h b/include/treelite/tree_impl.h index fcfa586c..05e38fa1 100644 --- a/include/treelite/tree_impl.h +++ b/include/treelite/tree_impl.h @@ -296,9 +296,9 @@ ModelParam::InitAllowUnknown(const Container& kwargs) { TREELITE_MAX_PRED_TRANSFORM_LENGTH - 1); this->pred_transform[TREELITE_MAX_PRED_TRANSFORM_LENGTH - 1] = '\0'; } else if (e.first == "sigmoid_alpha") { - this->sigmoid_alpha = dmlc::stof(e.second, nullptr); + this->sigmoid_alpha = std::stof(e.second, nullptr); } else if (e.first == "global_bias") { - this->global_bias = dmlc::stof(e.second, nullptr); + this->global_bias = std::stof(e.second, nullptr); } } return unknowns; diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 7aaa9b42..b56222d5 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -102,7 +102,7 @@ target_sources(objtreelite filesystem.cc optable.cc serializer.cc - reference_serializer.cc + json_serializer.cc ${PROJECT_SOURCE_DIR}/include/treelite/annotator.h ${PROJECT_SOURCE_DIR}/include/treelite/base.h ${PROJECT_SOURCE_DIR}/include/treelite/c_api.h diff --git a/src/annotator.cc b/src/annotator.cc index d540861d..882ee615 100644 --- a/src/annotator.cc +++ b/src/annotator.cc @@ -8,7 +8,10 @@ #include #include #include -#include +#include +#include +#include +#include #include #include #include @@ -23,7 +26,7 @@ union Entry { template void Traverse_(const treelite::Tree& tree, - const Entry* data, int nid, size_t* out_counts) { + const Entry* data, int nid, uint64_t* out_counts) { ++out_counts[nid]; if (!tree.IsLeaf(nid)) { const unsigned split_index = tree.SplitIndex(nid); @@ -58,7 +61,7 @@ void Traverse_(const treelite::Tree& tree, template void Traverse(const treelite::Tree& tree, - const Entry* data, size_t* out_counts) { + const Entry* data, uint64_t* out_counts) { Traverse_(tree, data, 0, out_counts); } @@ -66,7 +69,7 @@ template inline void ComputeBranchLoopImpl( const treelite::ModelImpl& model, const treelite::DenseDMatrixImpl* dmat, size_t rbegin, size_t rend, int nthread, - const size_t* count_row_ptr, size_t* counts_tloc) { + const size_t* count_row_ptr, uint64_t* 
counts_tloc) { std::vector> inst(nthread * dmat->num_col, {-1}); const size_t ntree = model.trees.size(); CHECK_LE(rbegin, rend); @@ -103,7 +106,7 @@ template inline void ComputeBranchLoopImpl( const treelite::ModelImpl& model, const treelite::CSRDMatrixImpl* dmat, size_t rbegin, size_t rend, int nthread, - const size_t* count_row_ptr, size_t* counts_tloc) { + const size_t* count_row_ptr, uint64_t* counts_tloc) { std::vector> inst(nthread * dmat->num_col, {-1}); const size_t ntree = model.trees.size(); CHECK_LE(rbegin, rend); @@ -136,7 +139,7 @@ class ComputeBranchLoopDispatcherWithDenseDMatrix { inline static void Dispatch( const treelite::ModelImpl& model, const treelite::DMatrix* dmat, size_t rbegin, size_t rend, int nthread, - const size_t* count_row_ptr, size_t* counts_tloc) { + const size_t* count_row_ptr, uint64_t* counts_tloc) { const auto* dmat_ = static_cast*>(dmat); CHECK(dmat_) << "Dangling data matrix reference detected"; ComputeBranchLoopImpl(model, dmat_, rbegin, rend, nthread, count_row_ptr, counts_tloc); @@ -150,7 +153,7 @@ class ComputeBranchLoopDispatcherWithCSRDMatrix { inline static void Dispatch( const treelite::ModelImpl& model, const treelite::DMatrix* dmat, size_t rbegin, size_t rend, int nthread, - const size_t* count_row_ptr, size_t* counts_tloc) { + const size_t* count_row_ptr, uint64_t* counts_tloc) { const auto* dmat_ = static_cast*>(dmat); CHECK(dmat_) << "Dangling data matrix reference detected"; ComputeBranchLoopImpl(model, dmat_, rbegin, rend, nthread, count_row_ptr, counts_tloc); @@ -161,7 +164,7 @@ template inline void ComputeBranchLoop(const treelite::ModelImpl& model, const treelite::DMatrix* dmat, size_t rbegin, size_t rend, int nthread, const size_t* count_row_ptr, - size_t* counts_tloc) { + uint64_t* counts_tloc) { switch (dmat->GetType()) { case treelite::DMatrixType::kDense: { treelite::DispatchWithTypeInfo( @@ -189,9 +192,9 @@ inline void AnnotateImpl( const treelite::ModelImpl& model, const treelite::DMatrix* dmat, int nthread, int verbose, - std::vector>* out_counts) { - std::vector new_counts; - std::vector counts_tloc; + std::vector>* out_counts) { + std::vector new_counts; + std::vector counts_tloc; std::vector count_row_ptr; count_row_ptr = {0}; @@ -224,7 +227,7 @@ AnnotateImpl( } // change layout of counts - std::vector>& counts = *out_counts; + std::vector>& counts = *out_counts; for (size_t i = 0; i < ntree; ++i) { counts.emplace_back(&new_counts[count_row_ptr[i]], &new_counts[count_row_ptr[i + 1]]); } @@ -234,22 +237,43 @@ void BranchAnnotator::Annotate(const Model& model, const DMatrix* dmat, int nthread, int verbose) { TypeInfo threshold_type = model.GetThresholdType(); model.Dispatch([this, dmat, nthread, verbose, threshold_type](auto& handle) { - AnnotateImpl(handle, dmat, nthread, verbose, &this->counts); + AnnotateImpl(handle, dmat, nthread, verbose, &this->counts_); }); } void -BranchAnnotator::Load(dmlc::Stream* fi) { - dmlc::istream is(fi); - std::unique_ptr reader(new dmlc::JSONReader(&is)); - reader->Read(&counts); +BranchAnnotator::Load(std::istream& fi) { + rapidjson::IStreamWrapper is(fi); + + rapidjson::Document doc; + doc.ParseStream(is); + + std::string err_msg = "JSON file must contain a list of lists of integers"; + CHECK(doc.IsArray()) << err_msg; + counts_.clear(); + for (const auto& node_cnt : doc.GetArray()) { + CHECK(node_cnt.IsArray()) << err_msg; + counts_.emplace_back(); + for (const auto& e : node_cnt.GetArray()) { + counts_.back().push_back(e.GetUint64()); + } + } } void -BranchAnnotator::Save(dmlc::Stream* 
fo) const { - dmlc::ostream os(fo); - std::unique_ptr writer(new dmlc::JSONWriter(&os)); - writer->Write(counts); +BranchAnnotator::Save(std::ostream& fo) const { + rapidjson::OStreamWrapper os(fo); + rapidjson::Writer writer(os); + + writer.StartArray(); + for (const auto& node_cnt : counts_) { + writer.StartArray(); + for (auto e : node_cnt) { + writer.Uint64(e); + } + writer.EndArray(); + } + writer.EndArray(); } } // namespace treelite diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index b25fbd13..339b1470 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -19,6 +19,8 @@ #include #include #include +#include +#include using namespace treelite; @@ -51,8 +53,8 @@ int TreeliteAnnotationSave(AnnotationHandle handle, const char* path) { API_BEGIN(); const BranchAnnotator* annotator = static_cast(handle); - std::unique_ptr fo(dmlc::Stream::Create(path, "w")); - annotator->Save(fo.get()); + std::ofstream fo(path); + annotator->Save(fo); API_END(); } diff --git a/src/compiler/ast/ast.h b/src/compiler/ast/ast.h index c37cd513..ee6f1716 100644 --- a/src/compiler/ast/ast.h +++ b/src/compiler/ast/ast.h @@ -1,5 +1,5 @@ /*! - * Copyright (c) 2017-2020 by Contributors + * Copyright (c) 2017-2021 by Contributors * \file ast.h * \brief Definition for AST classes * \author Hyunsu Cho @@ -14,6 +14,7 @@ #include #include #include +#include namespace treelite { namespace compiler { @@ -24,7 +25,7 @@ class ASTNode { std::vector children; int node_id; int tree_id; - dmlc::optional data_count; + dmlc::optional data_count; dmlc::optional sum_hess; virtual std::string GetDump() const = 0; virtual ~ASTNode() = 0; // force ASTNode to be abstract class diff --git a/src/compiler/ast/builder.h b/src/compiler/ast/builder.h index 169e6ca8..3284e2cb 100644 --- a/src/compiler/ast/builder.h +++ b/src/compiler/ast/builder.h @@ -13,6 +13,7 @@ #include #include #include +#include #include "./ast.h" namespace treelite { @@ -58,7 +59,7 @@ class ASTBuilder { /* \brief replace split thresholds with integers */ void QuantizeThresholds(); /* \brief Load data counts from annotation file */ - void LoadDataCounts(const std::vector>& counts); + void LoadDataCounts(const std::vector>& counts); /* * \brief Get a text representation of AST */ diff --git a/src/compiler/ast/load_data_counts.cc b/src/compiler/ast/load_data_counts.cc index 23693de2..61c79195 100644 --- a/src/compiler/ast/load_data_counts.cc +++ b/src/compiler/ast/load_data_counts.cc @@ -1,11 +1,11 @@ /*! 
- * Copyright (c) 2017-2020 by Contributors + * Copyright (c) 2017-2021 by Contributors * \file load_data_counts.cc * \brief AST manipulation logic to load data counts * \author Hyunsu Cho */ #include -#include +#include #include "./builder.h" namespace treelite { @@ -13,7 +13,7 @@ namespace compiler { DMLC_REGISTRY_FILE_TAG(load_data_counts); -static void load_data_counts(ASTNode* node, const std::vector>& counts) { +static void load_data_counts(ASTNode* node, const std::vector>& counts) { if (node->tree_id >= 0 && node->node_id >= 0) { node->data_count = counts[node->tree_id][node->node_id]; } @@ -25,14 +25,16 @@ static void load_data_counts(ASTNode* node, const std::vector void ASTBuilder::LoadDataCounts( - const std::vector>& counts) { + const std::vector>& counts) { load_data_counts(this->main_node, counts); } -template void ASTBuilder::LoadDataCounts(const std::vector>&); -template void ASTBuilder::LoadDataCounts(const std::vector>&); -template void ASTBuilder::LoadDataCounts(const std::vector>&); -template void ASTBuilder::LoadDataCounts(const std::vector>&); +template void ASTBuilder::LoadDataCounts( + const std::vector>&); +template void ASTBuilder::LoadDataCounts(const std::vector>&); +template void ASTBuilder::LoadDataCounts( + const std::vector>&); +template void ASTBuilder::LoadDataCounts(const std::vector>&); } // namespace compiler } // namespace treelite diff --git a/src/compiler/ast_native.cc b/src/compiler/ast_native.cc index 47b923c4..c20c544b 100644 --- a/src/compiler/ast_native.cc +++ b/src/compiler/ast_native.cc @@ -9,11 +9,15 @@ #include #include #include +#include +#include #include #include #include #include +#include #include +#include #include "./pred_transform.h" #include "./ast/builder.h" #include "./native/main_template.h" @@ -57,7 +61,7 @@ class ASTNativeCompiler : public Compiler { CHECK(model.task_type != TaskType::kMultiClfCategLeaf) << "Model task type unsupported by ASTNativeCompiler"; - CHECK(model.task_param.output_type == TaskParameter::OutputType::kFloat) + CHECK(model.task_param.output_type == TaskParam::OutputType::kFloat) << "ASTNativeCompiler only supports models with float output"; num_feature_ = model.num_feature; @@ -77,9 +81,8 @@ class ASTNativeCompiler : public Compiler { } if (param.annotate_in != "NULL") { BranchAnnotator annotator; - std::unique_ptr fi( - dmlc::Stream::Create(param.annotate_in.c_str(), "r")); - annotator.Load(fi.get()); + std::ifstream fi(param.annotate_in.c_str()); + annotator.Load(fi); const auto annotation = annotator.Get(); builder.LoadDataCounts(annotation); LOG(INFO) << "Loading node frequencies from `" @@ -105,23 +108,31 @@ class ASTNativeCompiler : public Compiler { { /* write recipe.json */ - std::vector> source_list; + rapidjson::StringBuffer os; + rapidjson::Writer writer(os); + + writer.StartObject(); + writer.Key("target"); + writer.String(param.native_lib_name.data(), param.native_lib_name.size()); + writer.Key("sources"); + writer.StartArray(); for (const auto& kv : files_) { if (kv.first.compare(kv.first.length() - 2, 2, ".c") == 0) { const size_t line_count = std::count(kv.second.content.begin(), kv.second.content.end(), '\n'); - source_list.push_back({ {"name", - kv.first.substr(0, kv.first.length() - 2)}, - {"length", std::to_string(line_count)} }); + writer.StartObject(); + writer.Key("name"); + std::string name = kv.first.substr(0, kv.first.length() - 2); + writer.String(name.data(), name.size()); + writer.Key("length"); + writer.Uint64(line_count); + writer.EndObject(); } } - std::ostringstream oss; 
- std::unique_ptr writer(new dmlc::JSONWriter(&oss)); - writer->BeginObject(); - writer->WriteObjectKeyValue("target", param.native_lib_name); - writer->WriteObjectKeyValue("sources", source_list); - writer->EndObject(); - files_["recipe.json"] = CompiledModel::FileEntry(oss.str()); + writer.EndArray(); + writer.EndObject(); + + files_["recipe.json"] = CompiledModel::FileEntry(os.GetString()); } cm.files = std::move(files_); return cm; @@ -140,7 +151,7 @@ class ASTNativeCompiler : public Compiler { CompilerParam param; int num_feature_; TaskType task_type_; - TaskParameter task_param_; + TaskParam task_param_; std::string pred_transform_; float sigmoid_alpha_; float global_bias_; @@ -331,8 +342,8 @@ class ASTNativeCompiler : public Compiler { condition_with_na_check = ExtractCategoricalCondition(t2); } if (node->children[0]->data_count && node->children[1]->data_count) { - const size_t left_freq = node->children[0]->data_count.value(); - const size_t right_freq = node->children[1]->data_count.value(); + const uint64_t left_freq = node->children[0]->data_count.value(); + const uint64_t right_freq = node->children[1]->data_count.value(); condition_with_na_check = fmt::format(" {keyword}( {condition} ) ", "keyword"_a = ((left_freq > right_freq) ? "LIKELY" : "UNLIKELY"), diff --git a/src/compiler/failsafe.cc b/src/compiler/failsafe.cc index b1f00783..cce1ebae 100644 --- a/src/compiler/failsafe.cc +++ b/src/compiler/failsafe.cc @@ -10,6 +10,8 @@ #include #include #include +#include +#include #include #include #include @@ -360,29 +362,42 @@ class FailSafeCompiler : public Compiler { { /* write recipe.json */ - std::vector> source_list; + rapidjson::StringBuffer os; + rapidjson::Writer writer(os); + + writer.StartObject(); + writer.Key("target"); + writer.String(param.native_lib_name.data(), param.native_lib_name.size()); + writer.Key("sources"); + writer.StartArray(); std::vector extra_file_list; for (const auto& kv : files_) { if (EndsWith(kv.first, ".c")) { const size_t line_count = std::count(kv.second.content.begin(), kv.second.content.end(), '\n'); - source_list.push_back({ {"name", - kv.first.substr(0, kv.first.length() - 2)}, - {"length", std::to_string(line_count)} }); + writer.StartObject(); + writer.Key("name"); + std::string name = kv.first.substr(0, kv.first.length() - 2); + writer.String(name.data(), name.size()); + writer.Key("length"); + writer.Uint64(line_count); + writer.EndObject(); } else if (EndsWith(kv.first, ".o")) { extra_file_list.push_back(kv.first); } } - std::ostringstream oss; - std::unique_ptr writer(new dmlc::JSONWriter(&oss)); - writer->BeginObject(); - writer->WriteObjectKeyValue("target", param.native_lib_name); - writer->WriteObjectKeyValue("sources", source_list); + writer.EndArray(); if (!extra_file_list.empty()) { - writer->WriteObjectKeyValue("extra", extra_file_list); + writer.Key("extra"); + writer.StartArray(); + for (const auto& extra_file : extra_file_list) { + writer.String(extra_file.data(), extra_file.size()); + } + writer.EndArray(); } - writer->EndObject(); - files_["recipe.json"] = CompiledModel::FileEntry(oss.str()); + writer.EndObject(); + + files_["recipe.json"] = CompiledModel::FileEntry(os.GetString()); } cm.files = std::move(files_); return cm; diff --git a/src/frontend/builder.cc b/src/frontend/builder.cc index d555c6d7..46603585 100644 --- a/src/frontend/builder.cc +++ b/src/frontend/builder.cc @@ -430,7 +430,7 @@ ModelBuilderImpl::CommitModelImpl(ModelImpl* out_ ModelImpl& model = *out_model; model.num_feature = this->num_feature; 
model.average_tree_output = this->average_tree_output; - model.task_param.output_type = TaskParameter::OutputType::kFloat; + model.task_param.output_type = TaskParam::OutputType::kFloat; model.task_param.num_class = this->num_class; // extra parameters InitParamAndCheck(&model.param, this->cfg); diff --git a/src/frontend/lightgbm.cc b/src/frontend/lightgbm.cc index 22eb5ad0..d17676b8 100644 --- a/src/frontend/lightgbm.cc +++ b/src/frontend/lightgbm.cc @@ -1,5 +1,5 @@ /*! - * Copyright (c) 2017-2020 by Contributors + * Copyright (c) 2017-2021 by Contributors * \file lightgbm.cc * \brief Frontend for LightGBM model * \author Hyunsu Cho @@ -11,21 +11,22 @@ #include #include #include +#include +#include namespace { -inline std::unique_ptr ParseStream(dmlc::Stream* fi); +inline std::unique_ptr ParseStream(std::istream& fi); } // anonymous namespace namespace treelite { namespace frontend { -DMLC_REGISTRY_FILE_TAG(lightgbm); -std::unique_ptr LoadLightGBMModel(const char *filename) { - std::unique_ptr fi(dmlc::Stream::Create(filename, "r")); - return ParseStream(fi.get()); +std::unique_ptr LoadLightGBMModel(const char* filename) { + std::ifstream fi(filename, std::ios::in); + return ParseStream(fi); } } // namespace frontend @@ -212,48 +213,17 @@ inline std::vector BitsetToList(const uint32_t* bits, return result; } -inline std::vector LoadText(dmlc::Stream* fi) { - const size_t bufsize = 16 * 1024 * 1024; // 16 MB - std::vector buf(bufsize); - +inline std::vector LoadText(std::istream& fi) { std::vector lines; - - size_t byte_read; - - std::string leftover = ""; // carry over between buffers - while ((byte_read = fi->Read(&buf[0], sizeof(char) * bufsize)) > 0) { - size_t i = 0; - size_t tok_begin = 0; - while (i < byte_read) { - if (buf[i] == '\n' || buf[i] == '\r') { // delimiter for lines - if (tok_begin == 0 && leftover.length() + i > 0) { - // first line in buffer - lines.push_back(leftover + std::string(&buf[0], i)); - leftover = ""; - } else { - lines.emplace_back(&buf[tok_begin], i - tok_begin); - } - // skip all delimiters afterwards - for (; (buf[i] == '\n' || buf[i] == '\r') && i < byte_read; ++i) {} - tok_begin = i; - } else { - ++i; - } - } - // left-over string - leftover += std::string(&buf[tok_begin], byte_read - tok_begin); - } - - if (!leftover.empty()) { - LOG(INFO) - << "Warning: input file was not terminated with end-of-line character."; - lines.push_back(leftover); + std::string line; + while (std::getline(fi, line)) { + lines.push_back(line); } return lines; } -inline std::unique_ptr ParseStream(dmlc::Stream* fi) { +inline std::unique_ptr ParseStream(std::istream& fi) { std::vector lgb_trees_; int max_feature_idx_; int num_class_; @@ -450,7 +420,7 @@ inline std::unique_ptr ParseStream(dmlc::Stream* fi) { model->task_type = treelite::TaskType::kBinaryClfRegr; model->task_param.grove_per_class = false; } - model->task_param.output_type = treelite::TaskParameter::OutputType::kFloat; + model->task_param.output_type = treelite::TaskParam::OutputType::kFloat; model->task_param.num_class = num_class_; model->task_param.leaf_vector_size = 1; diff --git a/src/frontend/sklearn.cc b/src/frontend/sklearn.cc index c63bcb14..3a8a1b15 100644 --- a/src/frontend/sklearn.cc +++ b/src/frontend/sklearn.cc @@ -79,7 +79,7 @@ std::unique_ptr LoadSKLearnRandomForestRegressor( model->average_tree_output = true; model->task_type = treelite::TaskType::kBinaryClfRegr; model->task_param.grove_per_class = false; - model->task_param.output_type = treelite::TaskParameter::OutputType::kFloat; + 
model->task_param.output_type = treelite::TaskParam::OutputType::kFloat; model->task_param.num_class = 1; model->task_param.leaf_vector_size = 1; std::strncpy(model->param.pred_transform, "identity", sizeof(model->param.pred_transform)); @@ -104,7 +104,7 @@ std::unique_ptr LoadSKLearnRandomForestClassifierBinary( model->average_tree_output = true; model->task_type = treelite::TaskType::kBinaryClfRegr; model->task_param.grove_per_class = false; - model->task_param.output_type = treelite::TaskParameter::OutputType::kFloat; + model->task_param.output_type = treelite::TaskParam::OutputType::kFloat; model->task_param.num_class = 1; model->task_param.leaf_vector_size = 1; std::strncpy(model->param.pred_transform, "identity", sizeof(model->param.pred_transform)); @@ -133,7 +133,7 @@ std::unique_ptr LoadSKLearnRandomForestClassifierMulticlass( model->average_tree_output = true; model->task_type = treelite::TaskType::kMultiClfProbDistLeaf; model->task_param.grove_per_class = false; - model->task_param.output_type = treelite::TaskParameter::OutputType::kFloat; + model->task_param.output_type = treelite::TaskParam::OutputType::kFloat; model->task_param.num_class = n_classes; model->task_param.leaf_vector_size = n_classes; std::strncpy(model->param.pred_transform, "identity_multiclass", @@ -183,7 +183,7 @@ std::unique_ptr LoadSKLearnGradientBoostingRegressor( model->average_tree_output = false; model->task_type = treelite::TaskType::kBinaryClfRegr; model->task_param.grove_per_class = false; - model->task_param.output_type = treelite::TaskParameter::OutputType::kFloat; + model->task_param.output_type = treelite::TaskParam::OutputType::kFloat; model->task_param.num_class = 1; model->task_param.leaf_vector_size = 1; std::strncpy(model->param.pred_transform, "identity", sizeof(model->param.pred_transform)); @@ -209,7 +209,7 @@ std::unique_ptr LoadSKLearnGradientBoostingClassifierBinary( model->average_tree_output = false; model->task_type = treelite::TaskType::kBinaryClfRegr; model->task_param.grove_per_class = false; - model->task_param.output_type = treelite::TaskParameter::OutputType::kFloat; + model->task_param.output_type = treelite::TaskParam::OutputType::kFloat; model->task_param.num_class = 1; model->task_param.leaf_vector_size = 1; std::strncpy(model->param.pred_transform, "sigmoid", sizeof(model->param.pred_transform)); @@ -235,7 +235,7 @@ std::unique_ptr LoadSKLearnGradientBoostingClassifierMulticlass model->average_tree_output = false; model->task_type = treelite::TaskType::kMultiClfGrovePerClass; model->task_param.grove_per_class = true; - model->task_param.output_type = treelite::TaskParameter::OutputType::kFloat; + model->task_param.output_type = treelite::TaskParam::OutputType::kFloat; model->task_param.num_class = n_classes; model->task_param.leaf_vector_size = 1; std::strncpy(model->param.pred_transform, "softmax", sizeof(model->param.pred_transform)); diff --git a/src/frontend/xgboost.cc b/src/frontend/xgboost.cc index 9feabe4b..df7926da 100644 --- a/src/frontend/xgboost.cc +++ b/src/frontend/xgboost.cc @@ -1,39 +1,38 @@ /*! 
- * Copyright (c) 2017-2020 by Contributors + * Copyright (c) 2017-2021 by Contributors * \file xgboost.cc * \brief Frontend for xgboost model * \author Hyunsu Cho */ #include "xgboost/xgboost.h" -#include -#include #include #include #include #include #include +#include +#include #include +#include namespace { -inline std::unique_ptr ParseStream(dmlc::Stream* fi); +inline std::unique_ptr ParseStream(std::istream& fi); } // anonymous namespace namespace treelite { namespace frontend { -DMLC_REGISTRY_FILE_TAG(xgboost); - std::unique_ptr LoadXGBoostModel(const char* filename) { - std::unique_ptr fi(dmlc::Stream::Create(filename, "r")); - return ParseStream(fi.get()); + std::ifstream fi(filename, std::ios::in | std::ios::binary); + return ParseStream(fi); } std::unique_ptr LoadXGBoostModel(const void* buf, size_t len) { - dmlc::MemoryFixedSizeStream fs(const_cast(buf), len); - return ParseStream(&fs); + std::istringstream fi(std::string(static_cast(buf), len)); + return ParseStream(fi); } } // namespace frontend @@ -49,7 +48,7 @@ class PeekableInputStream { public: const size_t MAX_PEEK_WINDOW = 1024; // peek up to 1024 bytes - explicit PeekableInputStream(dmlc::Stream* fi) + explicit PeekableInputStream(std::istream& fi) : istm_(fi), buf_(MAX_PEEK_WINDOW + 1), begin_ptr_(0), end_ptr_(0) {} inline size_t Read(void* ptr, size_t size) { @@ -77,8 +76,8 @@ class PeekableInputStream { bytes_buffered + begin_ptr_ - MAX_PEEK_WINDOW - 1); } begin_ptr_ = end_ptr_; - return bytes_buffered - + istm_->Read(cptr + bytes_buffered, bytes_to_read); + istm_.read(cptr + bytes_buffered, bytes_to_read); + return bytes_buffered + istm_.gcount(); } } @@ -92,15 +91,16 @@ class PeekableInputStream { if (size > bytes_buffered) { const size_t bytes_to_read = size - bytes_buffered; if (end_ptr_ + bytes_to_read < MAX_PEEK_WINDOW + 1) { - CHECK_EQ(istm_->Read(&buf_[end_ptr_], bytes_to_read), bytes_to_read) + istm_.read(&buf_[end_ptr_], bytes_to_read); + CHECK_EQ(istm_.gcount(), bytes_to_read) << "Failed to peek " << size << " bytes"; end_ptr_ += bytes_to_read; } else { - CHECK_EQ(istm_->Read(&buf_[end_ptr_], - MAX_PEEK_WINDOW + 1 - end_ptr_) - + istm_->Read(&buf_[0], - bytes_to_read + end_ptr_ - MAX_PEEK_WINDOW - 1), - bytes_to_read) + istm_.read(&buf_[end_ptr_], MAX_PEEK_WINDOW + 1 - end_ptr_); + size_t first_read = istm_.gcount(); + istm_.read(&buf_[0], bytes_to_read + end_ptr_ - MAX_PEEK_WINDOW - 1); + size_t second_read = istm_.gcount(); + CHECK_EQ(first_read + second_read, bytes_to_read) << "Ill-formed XGBoost model: Failed to peek " << size << " bytes"; end_ptr_ = bytes_to_read + end_ptr_ - MAX_PEEK_WINDOW - 1; } @@ -117,7 +117,7 @@ class PeekableInputStream { } private: - dmlc::Stream* istm_; + std::istream& istm_; std::vector buf_; size_t begin_ptr_, end_ptr_; @@ -298,14 +298,14 @@ class XGBTree { inline void Load(PeekableInputStream* fi) { CHECK_EQ(fi->Read(¶m, sizeof(TreeParam)), sizeof(TreeParam)) << "Ill-formed XGBoost model file: can't read TreeParam"; + CHECK_GT(param.num_nodes, 0) + << "Ill-formed XGBoost model file: a tree can't be empty"; nodes.resize(param.num_nodes); stats.resize(param.num_nodes); - CHECK_NE(param.num_nodes, 0) - << "Ill-formed XGBoost model file: a tree can't be empty"; - CHECK_EQ(fi->Read(dmlc::BeginPtr(nodes), sizeof(Node) * nodes.size()), + CHECK_EQ(fi->Read(nodes.data(), sizeof(Node) * nodes.size()), sizeof(Node) * nodes.size()) << "Ill-formed XGBoost model file: cannot read specified number of nodes"; - CHECK_EQ(fi->Read(dmlc::BeginPtr(stats), sizeof(NodeStat) * stats.size()), + 
CHECK_EQ(fi->Read(stats.data(), sizeof(NodeStat) * stats.size()), sizeof(NodeStat) * stats.size()) << "Ill-formed XGBoost model file: cannot read specified number of nodes"; if (param.size_leaf_vector != 0) { @@ -322,7 +322,7 @@ class XGBTree { } }; -inline std::unique_ptr ParseStream(dmlc::Stream* fi) { +inline std::unique_ptr ParseStream(std::istream& fi) { std::vector xgb_trees_; LearnerModelParam mparam_; // model parameter GBTreeModelParam gbm_param_; // GBTree training parameter @@ -373,6 +373,8 @@ inline std::unique_ptr ParseStream(dmlc::Stream* fi) { CHECK_EQ(fp->Read(&gbm_param_, sizeof(gbm_param_)), sizeof(gbm_param_)) << "Invalid XGBoost model file: corrupted GBTree parameters"; + CHECK_GE(gbm_param_.num_trees, 0) + << "Invalid XGBoost model file: num_trees must be 0 or greater"; for (int i = 0; i < gbm_param_.num_trees; ++i) { xgb_trees_.emplace_back(); xgb_trees_.back().Load(fp.get()); @@ -381,16 +383,21 @@ inline std::unique_ptr ParseStream(dmlc::Stream* fi) { // tree_info is currently unused. std::vector tree_info; tree_info.resize(gbm_param_.num_trees); - if (gbm_param_.num_trees != 0) { - CHECK_EQ(fp->Read(dmlc::BeginPtr(tree_info), sizeof(int32_t) * tree_info.size()), + if (gbm_param_.num_trees > 0) { + CHECK_EQ(fp->Read(tree_info.data(), sizeof(int32_t) * tree_info.size()), sizeof(int32_t) * tree_info.size()); } // Load weight drop values (per tree) for dart models. std::vector weight_drop; if (name_gbm_ == "dart") { weight_drop.resize(gbm_param_.num_trees); + uint64_t sz; + fi.read(reinterpret_cast(&sz), sizeof(uint64_t)); + CHECK_EQ(sz, gbm_param_.num_trees); if (gbm_param_.num_trees != 0) { - fi->Read(&weight_drop); + for (uint64_t i = 0; i < sz; ++i) { + fi.read(reinterpret_cast(&weight_drop[i]), sizeof(bst_float)); + } } } @@ -409,7 +416,7 @@ inline std::unique_ptr ParseStream(dmlc::Stream* fi) { model->task_type = treelite::TaskType::kBinaryClfRegr; model->task_param.grove_per_class = false; } - model->task_param.output_type = treelite::TaskParameter::OutputType::kFloat; + model->task_param.output_type = treelite::TaskParam::OutputType::kFloat; model->task_param.num_class = num_class; model->task_param.leaf_vector_size = 1; diff --git a/src/frontend/xgboost_json.cc b/src/frontend/xgboost_json.cc index aa428027..af36ab31 100644 --- a/src/frontend/xgboost_json.cc +++ b/src/frontend/xgboost_json.cc @@ -1,5 +1,5 @@ /*! 
- * Copyright (c) 2020 by Contributors + * Copyright (c) 2020-2021 by Contributors * \file xgboost_json.cc * \brief Frontend for xgboost model * \author Hyunsu Cho @@ -9,7 +9,6 @@ #include "xgboost/xgboost_json.h" #include -#include #include #include #include @@ -443,7 +442,7 @@ bool XGBoostModelHandler::EndObject(std::size_t memberCount) { return false; } output.model->average_tree_output = false; - output.model->task_param.output_type = TaskParameter::OutputType::kFloat; + output.model->task_param.output_type = TaskParam::OutputType::kFloat; output.model->task_param.leaf_vector_size = 1; if (output.model->task_param.num_class > 1) { // multi-class classifier diff --git a/src/gtil/predict.cc b/src/gtil/predict.cc index 51266b77..f066dcd7 100644 --- a/src/gtil/predict.cc +++ b/src/gtil/predict.cc @@ -69,7 +69,7 @@ inline std::size_t PredictImplInner(const treelite::ModelImplGetNumRow(); const size_t num_col = input->GetNumCol(); std::vector row(num_col); - const treelite::TaskParameter task_param = model.task_param; + const treelite::TaskParam task_param = model.task_param; std::vector sum(task_param.num_class); // TODO(phcho): Use parallelism @@ -143,7 +143,7 @@ inline std::size_t PredictImpl(const treelite::ModelImpl; - const treelite::TaskParameter task_param = model.task_param; + const treelite::TaskParam task_param = model.task_param; if (task_param.num_class > 1) { if (task_param.leaf_vector_size > 1) { // multi-class classification with random forest diff --git a/src/json_serializer.cc b/src/json_serializer.cc new file mode 100644 index 00000000..05bc9911 --- /dev/null +++ b/src/json_serializer.cc @@ -0,0 +1,174 @@ +/*! + * Copyright (c) 2020-2021 by Contributors + * \file json_serializer.cc + * \brief Reference serializer implementation, which serializes to JSON. 
This is useful for testing + * correctness of the binary serializer + * \author Hyunsu Cho + */ + +#include +#include +#include +#include +#include +#include +#include + +namespace { + +template +void WriteElement(WriterType& writer, double e) { + writer.Double(e); +} + +template +void WriteNode(WriterType& writer, + const typename treelite::Tree::Node& node) { + writer.StartObject(); + + writer.Key("cleft"); + writer.Int(node.cleft_); + writer.Key("cright"); + writer.Int(node.cright_); + writer.Key("split_index"); + writer.Uint(node.sindex_ & ((1U << 31U) - 1U)); + writer.Key("default_left"); + writer.Bool((node.sindex_ >> 31U) != 0); + if (node.cleft_ == -1) { + writer.Key("leaf_value"); + writer.Double(node.info_.leaf_value); + } else { + writer.Key("threshold"); + writer.Double(node.info_.threshold); + } + if (node.data_count_present_) { + writer.Key("data_count"); + writer.Uint64(node.data_count_); + } + if (node.sum_hess_present_) { + writer.Key("sum_hess"); + writer.Double(node.sum_hess_); + } + if (node.gain_present_) { + writer.Key("gain"); + writer.Double(node.gain_); + } + writer.Key("split_type"); + writer.Int(static_cast(node.split_type_)); + writer.Key("cmp"); + writer.Int(static_cast(node.cmp_)); + writer.Key("categories_list_right_child"); + writer.Bool(node.categories_list_right_child_); + + writer.EndObject(); +} + +template +void WriteContiguousArray(WriterType& writer, + const treelite::ContiguousArray& array) { + writer.StartArray(); + for (std::size_t i = 0; i < array.Size(); ++i) { + WriteElement(writer, array[i]); + } + writer.EndArray(); +} + +template +void SerializeTaskParamToJSON(WriterType& writer, treelite::TaskParam task_param) { + writer.StartObject(); + + writer.Key("output_type"); + writer.Uint(static_cast(task_param.output_type)); + writer.Key("grove_per_class"); + writer.Bool(task_param.grove_per_class); + writer.Key("num_class"); + writer.Uint(task_param.num_class); + writer.Key("leaf_vector_size"); + writer.Uint(task_param.leaf_vector_size); + + writer.EndObject(); +} + +template +void SerializeModelParamToJSON(WriterType& writer, treelite::ModelParam model_param) { + writer.StartObject(); + + writer.Key("pred_transform"); + std::string pred_transform(model_param.pred_transform); + writer.String(pred_transform.data(), pred_transform.size()); + writer.Key("sigmoid_alpha"); + writer.Double(model_param.sigmoid_alpha); + writer.Key("global_bias"); + writer.Double(model_param.global_bias); + + writer.EndObject(); +} + +} // anonymous namespace + +namespace treelite { + +template +void SerializeTreeToJSON(WriterType& writer, const Tree& tree) { + writer.StartObject(); + + writer.Key("num_nodes"); + writer.Int(tree.num_nodes); + writer.Key("leaf_vector"); + WriteContiguousArray(writer, tree.leaf_vector_); + writer.Key("leaf_vector_offset"); + WriteContiguousArray(writer, tree.leaf_vector_offset_); + writer.Key("matching_categories"); + WriteContiguousArray(writer, tree.matching_categories_); + writer.Key("matching_categories_offset"); + WriteContiguousArray(writer, tree.matching_categories_offset_); + writer.Key("nodes"); + writer.StartArray(); + for (std::size_t i = 0; i < tree.nodes_.Size(); ++i) { + WriteNode(writer, tree.nodes_[i]); + } + writer.EndArray(); + + writer.EndObject(); + + // Sanity check + CHECK_EQ(tree.nodes_.Size(), tree.num_nodes); + CHECK_EQ(tree.nodes_.Size() + 1, tree.leaf_vector_offset_.Size()); + CHECK_EQ(tree.leaf_vector_offset_.Back(), tree.leaf_vector_.Size()); + CHECK_EQ(tree.nodes_.Size() + 1, 
tree.matching_categories_offset_.Size()); + CHECK_EQ(tree.matching_categories_offset_.Back(), tree.matching_categories_.Size()); +} + +template +void ModelImpl::SerializeToJSON(std::ostream& fo) const { + rapidjson::OStreamWrapper os(fo); + rapidjson::Writer writer(os); + + writer.StartObject(); + + writer.Key("num_feature"); + writer.Int(num_feature); + writer.Key("task_type"); + writer.Uint(static_cast(task_type)); + writer.Key("average_tree_output"); + writer.Bool(average_tree_output); + writer.Key("task_param"); + SerializeTaskParamToJSON(writer, task_param); + writer.Key("model_param"); + SerializeModelParamToJSON(writer, param); + writer.Key("trees"); + writer.StartArray(); + for (const Tree& tree : trees) { + SerializeTreeToJSON(writer, tree); + } + writer.EndArray(); + + writer.EndObject(); +} + +template void ModelImpl::SerializeToJSON(std::ostream& fo) const; +template void ModelImpl::SerializeToJSON(std::ostream& fo) const; +template void ModelImpl::SerializeToJSON(std::ostream& fo) const; +template void ModelImpl::SerializeToJSON(std::ostream& fo) const; + +} // namespace treelite diff --git a/src/reference_serializer.cc b/src/reference_serializer.cc deleted file mode 100644 index 4e51b630..00000000 --- a/src/reference_serializer.cc +++ /dev/null @@ -1,82 +0,0 @@ -/*! - * Copyright (c) 2020 by Contributors - * \file reference_serializer.cc - * \brief Reference serializer implementation - * \author Hyunsu Cho - */ - -#include -#include -#include - -namespace dmlc { -namespace serializer { - -template -struct Handler> { - inline static void Write(Stream* strm, const treelite::ContiguousArray& data) { - uint64_t sz = static_cast(data.Size()); - strm->Write(sz); - strm->Write(data.Data(), sz * sizeof(T)); - } - - inline static bool Read(Stream* strm, treelite::ContiguousArray* data) { - uint64_t sz; - bool status = strm->Read(&sz); - if (!status) { - return false; - } - data->Resize(sz); - return strm->Read(data->Data(), sz * sizeof(T)); - } -}; - -} // namespace serializer -} // namespace dmlc - -namespace treelite { - -template -void Tree::ReferenceSerialize(dmlc::Stream* fo) const { - fo->Write(num_nodes); - fo->Write(leaf_vector_); - fo->Write(leaf_vector_offset_); - fo->Write(matching_categories_); - fo->Write(matching_categories_offset_); - uint64_t sz = static_cast(nodes_.Size()); - fo->Write(sz); - fo->Write(nodes_.Data(), sz * sizeof(Tree::Node)); - - // Sanity check - CHECK_EQ(nodes_.Size(), num_nodes); - CHECK_EQ(nodes_.Size() + 1, leaf_vector_offset_.Size()); - CHECK_EQ(leaf_vector_offset_.Back(), leaf_vector_.Size()); - CHECK_EQ(nodes_.Size() + 1, matching_categories_offset_.Size()); - CHECK_EQ(matching_categories_offset_.Back(), matching_categories_.Size()); -} - -template -void ModelImpl::ReferenceSerialize(dmlc::Stream* fo) const { - fo->Write(num_feature); - fo->Write(static_cast(task_type)); - fo->Write(average_tree_output); - fo->Write(&task_param, sizeof(task_param)); - fo->Write(¶m, sizeof(param)); - uint64_t sz = static_cast(trees.size()); - fo->Write(sz); - for (const Tree& tree : trees) { - tree.ReferenceSerialize(fo); - } -} - -template void Tree::ReferenceSerialize(dmlc::Stream* fo) const; -template void Tree::ReferenceSerialize(dmlc::Stream* fo) const; -template void Tree::ReferenceSerialize(dmlc::Stream* fo) const; -template void Tree::ReferenceSerialize(dmlc::Stream* fo) const; - -template void ModelImpl::ReferenceSerialize(dmlc::Stream* fo) const; -template void ModelImpl::ReferenceSerialize(dmlc::Stream* fo) const; -template void 
ModelImpl::ReferenceSerialize(dmlc::Stream* fo) const; -template void ModelImpl::ReferenceSerialize(dmlc::Stream* fo) const; - -} // namespace treelite diff --git a/tests/cpp/test_serializer.cc b/tests/cpp/test_serializer.cc index 257c8ce2..065abdf3 100644 --- a/tests/cpp/test_serializer.cc +++ b/tests/cpp/test_serializer.cc @@ -7,7 +7,7 @@ #include #include #include -#include +#include #include #include #include @@ -16,11 +16,9 @@ namespace { inline std::string TreeliteToBytes(treelite::Model* model) { - std::string s; - std::unique_ptr mstrm{new dmlc::MemoryStringStream(&s)}; - model->ReferenceSerialize(mstrm.get()); - mstrm.reset(); - return s; + std::ostringstream oss; + model->SerializeToJSON(oss); + return oss.str(); } inline void TestRoundTrip(treelite::Model* model) { @@ -220,7 +218,7 @@ template void PyBufferInterfaceRoundTrip_DeepFullTree() { TypeInfo threshold_type = TypeToInfo(); TypeInfo leaf_output_type = TypeToInfo(); - const int depth = 17; + const int depth = 12; std::unique_ptr builder{ new frontend::ModelBuilder(3, 1, false, threshold_type, leaf_output_type)
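Usage sketch (editor's note, not part of the patch): how the iostream-based interfaces introduced by this change are meant to be driven. It assumes a treelite::Model and a treelite::DMatrix obtained elsewhere (e.g. via the frontend loaders), and that DMatrix is declared in <treelite/data.h>; the BranchAnnotator and SerializeToJSON signatures are taken from the diff above.

#include <treelite/annotator.h>
#include <treelite/data.h>
#include <treelite/tree.h>
#include <fstream>
#include <sstream>
#include <string>

void Demo(const treelite::Model& model, const treelite::DMatrix* dmat) {
  // Branch annotation now round-trips through std::ostream / std::istream
  // instead of dmlc::Stream.
  treelite::BranchAnnotator annotator;
  annotator.Annotate(model, dmat, /*nthread=*/1, /*verbose=*/0);
  {
    std::ofstream fo("annotation.json");
    annotator.Save(fo);
  }
  {
    std::ifstream fi("annotation.json");
    treelite::BranchAnnotator annotator2;
    annotator2.Load(fi);
  }

  // The reference JSON serializer writes to any std::ostream.
  std::ostringstream oss;
  model.SerializeToJSON(oss);
  const std::string json = oss.str();
}

Because the serializers now target standard streams, they can be exercised entirely in memory with std::ostringstream, which is exactly what the updated TreeliteToBytes helper in tests/cpp/test_serializer.cc does.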