From 1835f92e4d7ee2ead5f1c8a6c39902e76f834d61 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Thu, 13 Jan 2022 15:57:24 +0000 Subject: [PATCH] Add a test to prepare for integer default_left --- src/frontend/xgboost/xgboost_json.h | 12 ++++-- tests/cpp/test_frontend.cc | 61 +++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+), 4 deletions(-) diff --git a/src/frontend/xgboost/xgboost_json.h b/src/frontend/xgboost/xgboost_json.h index 501af8a4..ee1e9f64 100644 --- a/src/frontend/xgboost/xgboost_json.h +++ b/src/frontend/xgboost/xgboost_json.h @@ -349,14 +349,18 @@ class RootHandler : public OutputHandler> { class DelegatedHandler : public rapidjson::BaseReaderHandler, DelegatedHandler>, public Delegator { - public: - /*! \brief create DelegatedHandler with initial RootHandler on stack */ - static std::shared_ptr create() { + /*! \brief create DelegatedHandler with empty stack */ + static std::shared_ptr create_empty() { struct make_shared_enabler : public DelegatedHandler {}; - std::shared_ptr new_handler = std::make_shared(); + return new_handler; + } + + /*! \brief create DelegatedHandler with initial RootHandler on stack */ + static std::shared_ptr create() { + std::shared_ptr new_handler = create_empty(); new_handler->push_delegate(std::make_shared( new_handler, new_handler->result)); diff --git a/tests/cpp/test_frontend.cc b/tests/cpp/test_frontend.cc index 92f2495b..af508e65 100644 --- a/tests/cpp/test_frontend.cc +++ b/tests/cpp/test_frontend.cc @@ -7,10 +7,13 @@ #include #include #include +#include #include #include #include "xgboost/xgboost_json.h" +using namespace fmt::literals; + namespace treelite { class MockDelegator : public details::Delegator { @@ -367,6 +370,64 @@ TEST(RegTreeHandlerSuite, RegTreeHandler) { reader.Parse(input_stream, handler); } +TEST(RegTreeHandlerSuite, DefaultLeft) { + class RegTreeHandlerWrapper : public details::BaseHandler { + public: + using BaseHandler::BaseHandler; + bool StartObject() override { + push_handler>(output); + return true; + } + Tree output; + }; + auto handler = details::DelegatedHandler::create_empty(); + auto wrapped_handler = std::make_shared(handler); + handler->push_delegate(wrapped_handler); + rapidjson::Reader reader; + + auto gather_default_left_array = [](const Tree& tree) { + std::vector default_left; + for (int nid = 0; nid < tree.num_nodes; ++nid) { + default_left.push_back(tree.DefaultLeft(nid)); + } + return default_left; + }; + + // default_left array can be integers; Treelite should automatically cast + // the elements to booleans. + std::string json_str_fmt = R"JSON( + {{"base_weights": [1.0, 0.0, 2.0, -0.5, 0.5, 1.5, 2.5], + "default_left": [{default_left}], + "id": 0, + "left_children": [1, 3, 5, -1, -1, -1, -1], + "loss_changes": [4.0, 0.5, 0.5, 0.0, 0.0, 0.0, 0.0], + "parents": [2147483647, 0, 0, 1, 1, 2, 2], + "right_children": [2, 4, 6, -1, -1, -1, -1], + "split_conditions": [2.0, 1.0, 3.0, -0.5, 0.5, 1.5, 2.5], + "split_indices": [0, 0, 0, 0, 0, 0, 0], + "split_type": [0, 0, 0, 0, 0, 0, 0], + "sum_hessian": [4.0, 2.0, 2.0, 1.0, 1.0, 1.0, 1.0], + "tree_param": {{ + "num_deleted": "0", + "num_feature": "1", + "num_nodes": "7", + "size_leaf_vector": "0"}} + }})JSON"; + std::string json_str = fmt::format(json_str_fmt, + "default_left"_a = fmt::join(std::vector{1, 1, 1, 0, 0, 0, 0}, ", ")); + auto input_stream = rapidjson::MemoryStream(json_str.c_str(), json_str.size()); + ASSERT_TRUE(reader.Parse(input_stream, *handler)); + std::vector expected_default_left{true, true, true, false, false, false, false}; + ASSERT_EQ(gather_default_left_array(wrapped_handler->output), expected_default_left); + + json_str = fmt::format(json_str_fmt, + "default_left"_a = fmt::join( + std::vector{true, true, true, false, false, false, false}, ", ")); + input_stream = rapidjson::MemoryStream(json_str.c_str(), json_str.size()); + ASSERT_TRUE(reader.Parse(input_stream, *handler)); + ASSERT_EQ(gather_default_left_array(wrapped_handler->output), expected_default_left); +} + /****************************************************************************** * GBTreeModelHandler * ***************************************************************************/