Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove the use of dmlc::JSONWriter, dmlc::Stream, and dmlc serializer #289

Merged
merged 4 commits into from
Jul 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions include/treelite/annotator.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*!
* Copyright (c) 2017-2020 by Contributors
* Copyright (c) 2017-2021 by Contributors
* \file annotator.h
* \author Hyunsu Cho
* \brief Branch annotation tools
Expand All @@ -9,7 +9,11 @@

#include <treelite/tree.h>
#include <treelite/data.h>
#include <istream>
#include <ostream>
#include <vector>
#include <cstdio>
#include <cstdint>

namespace treelite {

Expand All @@ -29,12 +33,12 @@ class BranchAnnotator {
* \brief load branch annotation from a JSON file
* \param fi input stream
*/
void Load(dmlc::Stream* fi);
void Load(std::istream& fi);
/*!
* \brief save branch annotation to a JSON file
* \param fo output stream
*/
void Save(dmlc::Stream* fo) const;
void Save(std::ostream& fo) const;
/*!
* \brief fetch branch annotation.
* Usage example:
Expand All @@ -48,12 +52,12 @@ class BranchAnnotator {
* \endcode
* \return branch annotation in 2D vector
*/
inline std::vector<std::vector<size_t>> Get() const {
return counts;
inline std::vector<std::vector<uint64_t>> Get() const {
return counts_;
}

private:
std::vector<std::vector<size_t>> counts;
std::vector<std::vector<uint64_t>> counts_;
};

} // namespace treelite
Expand Down
24 changes: 9 additions & 15 deletions include/treelite/tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <algorithm>
#include <map>
#include <memory>
#include <ostream>
#include <string>
#include <vector>
#include <utility>
Expand All @@ -27,14 +28,6 @@

#define TREELITE_MAX_PRED_TRANSFORM_LENGTH 256

/* Forward declarations */
namespace dmlc {

class Stream;
float stof(const std::string& value, std::size_t* pos);

} // namespace dmlc

namespace treelite {

// Represent a frame in the Python buffer protocol (PEP 3118). We use a simplified representation
Expand Down Expand Up @@ -161,7 +154,7 @@ enum class TaskType : uint8_t {
};

/*! \brief Group of parameters that are dependent on the choice of the task type. */
struct TaskParameter {
struct TaskParam {
enum class OutputType : uint8_t { kFloat = 0, kInt = 1 };
/*! \brief The type of output from each leaf node. */
OutputType output_type;
Expand Down Expand Up @@ -190,7 +183,7 @@ struct TaskParameter {
unsigned int leaf_vector_size;
};

static_assert(std::is_pod<TaskParameter>::value, "TaskParameter must be POD type");
// Compile-time guard: the build fails if TaskParam ever stops being a POD type.
static_assert(std::is_pod<TaskParam>::value, "TaskParam must be POD type");

/*! \brief in-memory representation of a decision tree */
template <typename ThresholdType, typename LeafOutputType>
Expand Down Expand Up @@ -289,6 +282,9 @@ class Tree {
ContiguousArray<uint32_t> matching_categories_;
ContiguousArray<std::size_t> matching_categories_offset_;

template <typename WriterType, typename X, typename Y>
friend void SerializeTreeToJSON(WriterType& writer, const Tree<X, Y>& tree);

// allocate a new node
inline int AllocNode();

Expand Down Expand Up @@ -562,8 +558,6 @@ class Tree {
node.gain_ = gain;
node.gain_present_ = true;
}

void ReferenceSerialize(dmlc::Stream* fo) const;
};

struct ModelParam {
Expand Down Expand Up @@ -656,7 +650,7 @@ class Model {

virtual std::size_t GetNumTree() const = 0;
virtual void SetTreeLimit(std::size_t limit) = 0;
virtual void ReferenceSerialize(dmlc::Stream* fo) const = 0;
virtual void SerializeToJSON(std::ostream& fo) const = 0;

/* In-memory serialization, zero-copy */
std::vector<PyBufferFrame> GetPyBuffer();
Expand All @@ -676,7 +670,7 @@ class Model {
/*! \brief whether to average tree outputs */
bool average_tree_output;
/*! \brief Group of parameters that are specific to the particular task type */
TaskParameter task_param;
TaskParam task_param;
/*! \brief extra parameters */
ModelParam param;

Expand Down Expand Up @@ -712,7 +706,7 @@ class ModelImpl : public Model {
ModelImpl(ModelImpl&&) noexcept = default;
ModelImpl& operator=(ModelImpl&&) noexcept = default;

void ReferenceSerialize(dmlc::Stream* fo) const override;
void SerializeToJSON(std::ostream& fo) const override;
inline std::size_t GetNumTree() const override {
return trees.size();
}
Expand Down
4 changes: 2 additions & 2 deletions include/treelite/tree_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -296,9 +296,9 @@ ModelParam::InitAllowUnknown(const Container& kwargs) {
TREELITE_MAX_PRED_TRANSFORM_LENGTH - 1);
this->pred_transform[TREELITE_MAX_PRED_TRANSFORM_LENGTH - 1] = '\0';
} else if (e.first == "sigmoid_alpha") {
this->sigmoid_alpha = dmlc::stof(e.second, nullptr);
this->sigmoid_alpha = std::stof(e.second, nullptr);
} else if (e.first == "global_bias") {
this->global_bias = dmlc::stof(e.second, nullptr);
this->global_bias = std::stof(e.second, nullptr);
}
}
return unknowns;
Expand Down
2 changes: 1 addition & 1 deletion src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ target_sources(objtreelite
filesystem.cc
optable.cc
serializer.cc
reference_serializer.cc
json_serializer.cc
${PROJECT_SOURCE_DIR}/include/treelite/annotator.h
${PROJECT_SOURCE_DIR}/include/treelite/base.h
${PROJECT_SOURCE_DIR}/include/treelite/c_api.h
Expand Down
66 changes: 45 additions & 21 deletions src/annotator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@
#include <treelite/annotator.h>
#include <treelite/math.h>
#include <treelite/omp.h>
#include <dmlc/json.h>
#include <rapidjson/istreamwrapper.h>
#include <rapidjson/ostreamwrapper.h>
#include <rapidjson/writer.h>
#include <rapidjson/document.h>
#include <dmlc/io.h>
#include <limits>
#include <cstdint>
Expand All @@ -23,7 +26,7 @@ union Entry {

template <typename ElementType, typename ThresholdType, typename LeafOutputType>
void Traverse_(const treelite::Tree<ThresholdType, LeafOutputType>& tree,
const Entry<ElementType>* data, int nid, size_t* out_counts) {
const Entry<ElementType>* data, int nid, uint64_t* out_counts) {
++out_counts[nid];
if (!tree.IsLeaf(nid)) {
const unsigned split_index = tree.SplitIndex(nid);
Expand Down Expand Up @@ -58,15 +61,15 @@ void Traverse_(const treelite::Tree<ThresholdType, LeafOutputType>& tree,

template <typename ElementType, typename ThresholdType, typename LeafOutputType>
void Traverse(const treelite::Tree<ThresholdType, LeafOutputType>& tree,
const Entry<ElementType>* data, size_t* out_counts) {
const Entry<ElementType>* data, uint64_t* out_counts) {
Traverse_(tree, data, 0, out_counts);
}

template <typename ElementType, typename ThresholdType, typename LeafOutputType>
inline void ComputeBranchLoopImpl(
const treelite::ModelImpl<ThresholdType, LeafOutputType>& model,
const treelite::DenseDMatrixImpl<ElementType>* dmat, size_t rbegin, size_t rend, int nthread,
const size_t* count_row_ptr, size_t* counts_tloc) {
const size_t* count_row_ptr, uint64_t* counts_tloc) {
std::vector<Entry<ElementType>> inst(nthread * dmat->num_col, {-1});
const size_t ntree = model.trees.size();
CHECK_LE(rbegin, rend);
Expand Down Expand Up @@ -103,7 +106,7 @@ template <typename ElementType, typename ThresholdType, typename LeafOutputType>
inline void ComputeBranchLoopImpl(
const treelite::ModelImpl<ThresholdType, LeafOutputType>& model,
const treelite::CSRDMatrixImpl<ElementType>* dmat, size_t rbegin, size_t rend, int nthread,
const size_t* count_row_ptr, size_t* counts_tloc) {
const size_t* count_row_ptr, uint64_t* counts_tloc) {
std::vector<Entry<ElementType>> inst(nthread * dmat->num_col, {-1});
const size_t ntree = model.trees.size();
CHECK_LE(rbegin, rend);
Expand Down Expand Up @@ -136,7 +139,7 @@ class ComputeBranchLoopDispatcherWithDenseDMatrix {
inline static void Dispatch(
const treelite::ModelImpl<ThresholdType, LeafOutputType>& model,
const treelite::DMatrix* dmat, size_t rbegin, size_t rend, int nthread,
const size_t* count_row_ptr, size_t* counts_tloc) {
const size_t* count_row_ptr, uint64_t* counts_tloc) {
const auto* dmat_ = static_cast<const treelite::DenseDMatrixImpl<ElementType>*>(dmat);
CHECK(dmat_) << "Dangling data matrix reference detected";
ComputeBranchLoopImpl(model, dmat_, rbegin, rend, nthread, count_row_ptr, counts_tloc);
Expand All @@ -150,7 +153,7 @@ class ComputeBranchLoopDispatcherWithCSRDMatrix {
inline static void Dispatch(
const treelite::ModelImpl<ThresholdType, LeafOutputType>& model,
const treelite::DMatrix* dmat, size_t rbegin, size_t rend, int nthread,
const size_t* count_row_ptr, size_t* counts_tloc) {
const size_t* count_row_ptr, uint64_t* counts_tloc) {
const auto* dmat_ = static_cast<const treelite::CSRDMatrixImpl<ElementType>*>(dmat);
CHECK(dmat_) << "Dangling data matrix reference detected";
ComputeBranchLoopImpl(model, dmat_, rbegin, rend, nthread, count_row_ptr, counts_tloc);
Expand All @@ -161,7 +164,7 @@ template <typename ThresholdType, typename LeafOutputType>
inline void ComputeBranchLoop(const treelite::ModelImpl<ThresholdType, LeafOutputType>& model,
const treelite::DMatrix* dmat, size_t rbegin,
size_t rend, int nthread, const size_t* count_row_ptr,
size_t* counts_tloc) {
uint64_t* counts_tloc) {
switch (dmat->GetType()) {
case treelite::DMatrixType::kDense: {
treelite::DispatchWithTypeInfo<ComputeBranchLoopDispatcherWithDenseDMatrix>(
Expand Down Expand Up @@ -189,9 +192,9 @@ inline void
AnnotateImpl(
const treelite::ModelImpl<ThresholdType, LeafOutputType>& model,
const treelite::DMatrix* dmat, int nthread, int verbose,
std::vector<std::vector<size_t>>* out_counts) {
std::vector<size_t> new_counts;
std::vector<size_t> counts_tloc;
std::vector<std::vector<uint64_t>>* out_counts) {
std::vector<uint64_t> new_counts;
std::vector<uint64_t> counts_tloc;
std::vector<size_t> count_row_ptr;

count_row_ptr = {0};
Expand Down Expand Up @@ -224,7 +227,7 @@ AnnotateImpl(
}

// change layout of counts
std::vector<std::vector<size_t>>& counts = *out_counts;
std::vector<std::vector<uint64_t>>& counts = *out_counts;
for (size_t i = 0; i < ntree; ++i) {
counts.emplace_back(&new_counts[count_row_ptr[i]], &new_counts[count_row_ptr[i + 1]]);
}
Expand All @@ -234,22 +237,43 @@ void
BranchAnnotator::Annotate(const Model& model, const DMatrix* dmat, int nthread, int verbose) {
TypeInfo threshold_type = model.GetThresholdType();
model.Dispatch([this, dmat, nthread, verbose, threshold_type](auto& handle) {
AnnotateImpl(handle, dmat, nthread, verbose, &this->counts);
AnnotateImpl(handle, dmat, nthread, verbose, &this->counts_);
});
}

void
BranchAnnotator::Load(dmlc::Stream* fi) {
dmlc::istream is(fi);
std::unique_ptr<dmlc::JSONReader> reader(new dmlc::JSONReader(&is));
reader->Read(&counts);
BranchAnnotator::Load(std::istream& fi) {
  // Replace the stored branch counts with the contents of the JSON document read from `fi`.
  // The document must be an array of arrays of unsigned 64-bit integers; any other shape
  // (or a malformed document) trips a CHECK.
  rapidjson::IStreamWrapper is(fi);

  rapidjson::Document doc;
  doc.ParseStream(is);

  const std::string err_msg = "JSON file must contain a list of lists of integers";
  CHECK(!doc.HasParseError())
    << "Failed to parse JSON at offset " << doc.GetErrorOffset() << ". " << err_msg;
  CHECK(doc.IsArray()) << err_msg;

  // Build the result in a scratch buffer so that counts_ is never left half-populated
  // when one of the CHECKs below fires.
  std::vector<std::vector<uint64_t>> parsed;
  parsed.reserve(doc.Size());
  for (const auto& node_cnt : doc.GetArray()) {
    CHECK(node_cnt.IsArray()) << err_msg;
    parsed.emplace_back();
    parsed.back().reserve(node_cnt.Size());
    for (const auto& e : node_cnt.GetArray()) {
      // GetUint64() is only valid on values representable as uint64_t; reject strings,
      // floating-point numbers, negative numbers, nested containers, etc. up front.
      CHECK(e.IsUint64()) << err_msg;
      parsed.back().push_back(e.GetUint64());
    }
  }
  counts_.swap(parsed);
}

void
BranchAnnotator::Save(dmlc::Stream* fo) const {
dmlc::ostream os(fo);
std::unique_ptr<dmlc::JSONWriter> writer(new dmlc::JSONWriter(&os));
writer->Write(counts);
BranchAnnotator::Save(std::ostream& fo) const {
  // Serialize counts_ to `fo` as a JSON array of arrays of unsigned 64-bit integers:
  // one inner array per entry of counts_, elements written in stored order.
  rapidjson::OStreamWrapper stream_adapter(fo);
  rapidjson::Writer<rapidjson::OStreamWrapper> json(stream_adapter);

  json.StartArray();
  for (const std::vector<uint64_t>& inner : counts_) {
    json.StartArray();
    for (const uint64_t count : inner) {
      json.Uint64(count);
    }
    json.EndArray();
  }
  json.EndArray();
}

} // namespace treelite
6 changes: 4 additions & 2 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
#include <dmlc/io.h>
#include <memory>
#include <algorithm>
#include <fstream>
#include <cstdio>

using namespace treelite;

Expand Down Expand Up @@ -51,8 +53,8 @@ int TreeliteAnnotationSave(AnnotationHandle handle,
const char* path) {
API_BEGIN();
const BranchAnnotator* annotator = static_cast<BranchAnnotator*>(handle);
std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(path, "w"));
annotator->Save(fo.get());
std::ofstream fo(path);
annotator->Save(fo);
API_END();
}

Expand Down
5 changes: 3 additions & 2 deletions src/compiler/ast/ast.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*!
* Copyright (c) 2017-2020 by Contributors
* Copyright (c) 2017-2021 by Contributors
* \file ast.h
* \brief Definition for AST classes
* \author Hyunsu Cho
Expand All @@ -14,6 +14,7 @@
#include <string>
#include <vector>
#include <utility>
#include <cstdint>

namespace treelite {
namespace compiler {
Expand All @@ -24,7 +25,7 @@ class ASTNode {
std::vector<ASTNode*> children;
int node_id;
int tree_id;
dmlc::optional<size_t> data_count;
dmlc::optional<uint64_t> data_count;
dmlc::optional<double> sum_hess;
virtual std::string GetDump() const = 0;
virtual ~ASTNode() = 0; // force ASTNode to be abstract class
Expand Down
3 changes: 2 additions & 1 deletion src/compiler/ast/builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <ostream>
#include <utility>
#include <memory>
#include <cstdint>
#include "./ast.h"

namespace treelite {
Expand Down Expand Up @@ -58,7 +59,7 @@ class ASTBuilder {
/* \brief replace split thresholds with integers */
void QuantizeThresholds();
/* \brief Load data counts from annotation file */
void LoadDataCounts(const std::vector<std::vector<size_t>>& counts);
void LoadDataCounts(const std::vector<std::vector<uint64_t>>& counts);
/*
* \brief Get a text representation of AST
*/
Expand Down
Loading