Skip to content

Commit

Permalink
Add a test to prepare for integer default_left (#337)
Browse files Browse the repository at this point in the history
  • Loading branch information
hcho3 authored Jan 13, 2022
1 parent 9b7be0f commit c1c65f8
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 4 deletions.
12 changes: 8 additions & 4 deletions src/frontend/xgboost/xgboost_json.h
Original file line number Diff line number Diff line change
Expand Up @@ -349,14 +349,18 @@ class RootHandler : public OutputHandler<std::unique_ptr<treelite::Model>> {
class DelegatedHandler
: public rapidjson::BaseReaderHandler<rapidjson::UTF8<>, DelegatedHandler>,
public Delegator {

public:
/*! \brief create DelegatedHandler with initial RootHandler on stack */
static std::shared_ptr<DelegatedHandler> create() {
/*! \brief create DelegatedHandler with empty stack */
static std::shared_ptr<DelegatedHandler> create_empty() {
struct make_shared_enabler : public DelegatedHandler {};

std::shared_ptr<DelegatedHandler> new_handler =
std::make_shared<make_shared_enabler>();
return new_handler;
}

/*! \brief create DelegatedHandler with initial RootHandler on stack */
static std::shared_ptr<DelegatedHandler> create() {
std::shared_ptr<DelegatedHandler> new_handler = create_empty();
new_handler->push_delegate(std::make_shared<RootHandler>(
new_handler,
new_handler->result));
Expand Down
61 changes: 61 additions & 0 deletions tests/cpp/test_frontend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@
#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include <rapidjson/document.h>
#include <fmt/format.h>
#include <treelite/tree.h>
#include <treelite/frontend.h>
#include "xgboost/xgboost_json.h"

using namespace fmt::literals;

namespace treelite {

class MockDelegator : public details::Delegator {
Expand Down Expand Up @@ -367,6 +370,64 @@ TEST(RegTreeHandlerSuite, RegTreeHandler) {
reader.Parse(input_stream, handler);
}

TEST(RegTreeHandlerSuite, DefaultLeft) {
class RegTreeHandlerWrapper : public details::BaseHandler {
public:
using BaseHandler::BaseHandler;
bool StartObject() override {
push_handler<details::RegTreeHandler, Tree<float, float>>(output);
return true;
}
Tree<float, float> output;
};
auto handler = details::DelegatedHandler::create_empty();
auto wrapped_handler = std::make_shared<RegTreeHandlerWrapper>(handler);
handler->push_delegate(wrapped_handler);
rapidjson::Reader reader;

auto gather_default_left_array = [](const Tree<float, float>& tree) {
std::vector<bool> default_left;
for (int nid = 0; nid < tree.num_nodes; ++nid) {
default_left.push_back(tree.DefaultLeft(nid));
}
return default_left;
};

// default_left array can be integers; Treelite should automatically cast
// the elements to booleans.
std::string json_str_fmt = R"JSON(
{{"base_weights": [1.0, 0.0, 2.0, -0.5, 0.5, 1.5, 2.5],
"default_left": [{default_left}],
"id": 0,
"left_children": [1, 3, 5, -1, -1, -1, -1],
"loss_changes": [4.0, 0.5, 0.5, 0.0, 0.0, 0.0, 0.0],
"parents": [2147483647, 0, 0, 1, 1, 2, 2],
"right_children": [2, 4, 6, -1, -1, -1, -1],
"split_conditions": [2.0, 1.0, 3.0, -0.5, 0.5, 1.5, 2.5],
"split_indices": [0, 0, 0, 0, 0, 0, 0],
"split_type": [0, 0, 0, 0, 0, 0, 0],
"sum_hessian": [4.0, 2.0, 2.0, 1.0, 1.0, 1.0, 1.0],
"tree_param": {{
"num_deleted": "0",
"num_feature": "1",
"num_nodes": "7",
"size_leaf_vector": "0"}}
}})JSON";
std::string json_str = fmt::format(json_str_fmt,
"default_left"_a = fmt::join(std::vector<int>{1, 1, 1, 0, 0, 0, 0}, ", "));
auto input_stream = rapidjson::MemoryStream(json_str.c_str(), json_str.size());
ASSERT_TRUE(reader.Parse(input_stream, *handler));
std::vector<bool> expected_default_left{true, true, true, false, false, false, false};
ASSERT_EQ(gather_default_left_array(wrapped_handler->output), expected_default_left);

json_str = fmt::format(json_str_fmt,
"default_left"_a = fmt::join(
std::vector<bool>{true, true, true, false, false, false, false}, ", "));
input_stream = rapidjson::MemoryStream(json_str.c_str(), json_str.size());
ASSERT_TRUE(reader.Parse(input_stream, *handler));
ASSERT_EQ(gather_default_left_array(wrapped_handler->output), expected_default_left);
}

/******************************************************************************
* GBTreeModelHandler
* ***************************************************************************/
Expand Down

0 comments on commit c1c65f8

Please sign in to comment.