Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add JSON schema to model dump. #5660

Merged
merged 2 commits into from
May 15, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions doc/dump.schema
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
{
"$schema": "http://json-schema.org/draft-07/schema#",
"definitions": {
"split_node": {
"type": "object",
"properties": {
"nodeid": {
"type": "number",
"minimum": 0
},
"depth": {
"type": "number",
"minimum": 0
},
"yes": {
"type": "number",
"minimum": 0
},
"no": {
"type": "number",
"minimum": 0
},
"split": {
"type": "string"
},
"children": {
"type": "array",
"items": {
"oneOf": [
{"$ref": "#/definitions/split_node"},
{"$ref": "#/definitions/leaf_node"}
]
},
"maxItems": 2
}
},
"required": ["nodeid", "depth", "yes", "no", "split", "children"]
},
"leaf_node": {
"type": "object",
"properties": {
"nodeid": {
"type": "number",
"minimum": 0
},
"leaf": {
"type": "number"
}
},
"required": ["nodeid", "leaf"]
}
},
"type": "object",
"$ref": "#/definitions/split_node"
}
51 changes: 26 additions & 25 deletions src/tree/tree_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,20 +68,20 @@ class TreeGenerator {
return result;
}

virtual std::string Indicator(RegTree const& tree, int32_t nid, uint32_t depth) {
virtual std::string Indicator(RegTree const& tree, int32_t nid, uint32_t depth) const {
return "";
}
virtual std::string Integer(RegTree const& tree, int32_t nid, uint32_t depth) {
virtual std::string Integer(RegTree const& tree, int32_t nid, uint32_t depth) const {
return "";
}
virtual std::string Quantitive(RegTree const& tree, int32_t nid, uint32_t depth) {
virtual std::string Quantitive(RegTree const& tree, int32_t nid, uint32_t depth) const {
return "";
}
virtual std::string NodeStat(RegTree const& tree, int32_t nid) {
virtual std::string NodeStat(RegTree const& tree, int32_t nid) const {
return "";
}

virtual std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) = 0;
virtual std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) const = 0;

virtual std::string SplitNode(RegTree const& tree, int32_t nid, uint32_t depth) {
auto const split_index = tree[nid].SplitIndex();
Expand Down Expand Up @@ -110,7 +110,7 @@ class TreeGenerator {
return result;
}

virtual std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) = 0;
virtual std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) const = 0;
virtual std::string BuildTree(RegTree const& tree, int32_t nid, uint32_t depth) = 0;

public:
Expand Down Expand Up @@ -181,7 +181,7 @@ class TextGenerator : public TreeGenerator {
TextGenerator(FeatureMap const& fmap, std::string const& attrs, bool with_stats) :
TreeGenerator(fmap, with_stats) {}

std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) override {
std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) const override {
static std::string kLeafTemplate = "{tabs}{nid}:leaf={leaf}{stats}";
static std::string kStatTemplate = ",cover={cover}";
std::string result = SuperT::Match(
Expand All @@ -195,7 +195,7 @@ class TextGenerator : public TreeGenerator {
return result;
}

std::string Indicator(RegTree const& tree, int32_t nid, uint32_t depth) override {
std::string Indicator(RegTree const& tree, int32_t nid, uint32_t depth) const override {
static std::string const kIndicatorTemplate = "{nid}:[{fname}] yes={yes},no={no}";
int32_t nyes = tree[nid].DefaultLeft() ?
tree[nid].RightChild() : tree[nid].LeftChild();
Expand All @@ -211,7 +211,7 @@ class TextGenerator : public TreeGenerator {

std::string SplitNodeImpl(
RegTree const& tree, int32_t nid, std::string const& template_str,
std::string cond, uint32_t depth) {
std::string cond, uint32_t depth) const {
auto split_index = tree[nid].SplitIndex();
std::string const result = SuperT::Match(
template_str,
Expand All @@ -226,7 +226,7 @@ class TextGenerator : public TreeGenerator {
return result;
}

std::string Integer(RegTree const& tree, int32_t nid, uint32_t depth) override {
std::string Integer(RegTree const& tree, int32_t nid, uint32_t depth) const override {
static std::string const kIntegerTemplate =
"{tabs}{nid}:[{fname}<{cond}] yes={left},no={right},missing={missing}";
auto cond = tree[nid].SplitCond();
Expand All @@ -238,21 +238,21 @@ class TextGenerator : public TreeGenerator {
std::to_string(integer_threshold), depth);
}

std::string Quantitive(RegTree const& tree, int32_t nid, uint32_t depth) override {
std::string Quantitive(RegTree const& tree, int32_t nid, uint32_t depth) const override {
static std::string const kQuantitiveTemplate =
"{tabs}{nid}:[{fname}<{cond}] yes={left},no={right},missing={missing}";
auto cond = tree[nid].SplitCond();
return SplitNodeImpl(tree, nid, kQuantitiveTemplate, SuperT::ToStr(cond), depth);
}

std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) override {
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) const override {
auto cond = tree[nid].SplitCond();
static std::string const kNodeTemplate =
"{tabs}{nid}:[f{fname}<{cond}] yes={left},no={right},missing={missing}";
return SplitNodeImpl(tree, nid, kNodeTemplate, SuperT::ToStr(cond), depth);
}

std::string NodeStat(RegTree const& tree, int32_t nid) override {
std::string NodeStat(RegTree const& tree, int32_t nid) const override {
static std::string const kStatTemplate = ",gain={loss_chg},cover={sum_hess}";
std::string const result = SuperT::Match(
kStatTemplate,
Expand Down Expand Up @@ -297,15 +297,15 @@ class JsonGenerator : public TreeGenerator {
JsonGenerator(FeatureMap const& fmap, std::string attrs, bool with_stats) :
TreeGenerator(fmap, with_stats) {}

std::string Indent(uint32_t depth) {
std::string Indent(uint32_t depth) const {
std::string result;
for (uint32_t i = 0; i < depth + 1; ++i) {
result += " ";
}
return result;
}

std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) override {
std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) const override {
static std::string const kLeafTemplate =
R"L({ "nodeid": {nid}, "leaf": {leaf} {stat}})L";
static std::string const kStatTemplate =
Expand All @@ -321,11 +321,11 @@ class JsonGenerator : public TreeGenerator {
return result;
}

std::string Indicator(RegTree const& tree, int32_t nid, uint32_t depth) override {
std::string Indicator(RegTree const& tree, int32_t nid, uint32_t depth) const override {
int32_t nyes = tree[nid].DefaultLeft() ?
tree[nid].RightChild() : tree[nid].LeftChild();
static std::string const kIndicatorTemplate =
R"ID( "nodeid": {nid}, "depth": {depth}, "split": "{fname}", "yes": {yes}, "no": {no}})ID";
R"ID( "nodeid": {nid}, "depth": {depth}, "split": "{fname}", "yes": {yes}, "no": {no})ID";
auto split_index = tree[nid].SplitIndex();
auto result = SuperT::Match(
kIndicatorTemplate,
Expand All @@ -337,8 +337,9 @@ class JsonGenerator : public TreeGenerator {
return result;
}

std::string SplitNodeImpl(RegTree const& tree, int32_t nid,
std::string const& template_str, std::string cond, uint32_t depth) {
std::string SplitNodeImpl(RegTree const &tree, int32_t nid,
std::string const &template_str, std::string cond,
uint32_t depth) const {
auto split_index = tree[nid].SplitIndex();
std::string const result = SuperT::Match(
template_str,
Expand All @@ -353,7 +354,7 @@ class JsonGenerator : public TreeGenerator {
return result;
}

std::string Integer(RegTree const& tree, int32_t nid, uint32_t depth) override {
std::string Integer(RegTree const& tree, int32_t nid, uint32_t depth) const override {
auto cond = tree[nid].SplitCond();
const bst_float floored = std::floor(cond);
const int32_t integer_threshold
Expand All @@ -367,7 +368,7 @@ class JsonGenerator : public TreeGenerator {
std::to_string(integer_threshold), depth);
}

std::string Quantitive(RegTree const& tree, int32_t nid, uint32_t depth) override {
std::string Quantitive(RegTree const& tree, int32_t nid, uint32_t depth) const override {
static std::string const kQuantitiveTemplate =
R"I( "nodeid": {nid}, "depth": {depth}, "split": "{fname}", )I"
R"I("split_condition": {cond}, "yes": {left}, "no": {right}, )I"
Expand All @@ -376,7 +377,7 @@ class JsonGenerator : public TreeGenerator {
return SplitNodeImpl(tree, nid, kQuantitiveTemplate, SuperT::ToStr(cond), depth);
}

std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) override {
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) const override {
auto cond = tree[nid].SplitCond();
static std::string const kNodeTemplate =
R"I( "nodeid": {nid}, "depth": {depth}, "split": {fname}, )I"
Expand All @@ -385,7 +386,7 @@ class JsonGenerator : public TreeGenerator {
return SplitNodeImpl(tree, nid, kNodeTemplate, SuperT::ToStr(cond), depth);
}

std::string NodeStat(RegTree const& tree, int32_t nid) override {
std::string NodeStat(RegTree const& tree, int32_t nid) const override {
static std::string kStatTemplate =
R"S(, "gain": {loss_chg}, "cover": {sum_hess})S";
auto result = SuperT::Match(
Expand Down Expand Up @@ -529,7 +530,7 @@ class GraphvizGenerator : public TreeGenerator {
protected:
// Only indicator is different, so we combine all different node types into this
// function.
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) override {
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) const override {
auto split = tree[nid].SplitIndex();
auto cond = tree[nid].SplitCond();
static std::string const kNodeTemplate =
Expand Down Expand Up @@ -563,7 +564,7 @@ class GraphvizGenerator : public TreeGenerator {
return result;
};

std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) override {
std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) const override {
static std::string const kLeafTemplate =
" {nid} [ label=\"leaf={leaf-value}\" {params}]\n";
auto result = SuperT::Match(kLeafTemplate, {
Expand Down
4 changes: 4 additions & 0 deletions tests/cpp/tree/test_tree_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,10 @@ TEST(Tree, DumpJson) {

str = tree.DumpModel(fmap, false, "json");
ASSERT_EQ(str.find("cover"), std::string::npos);


auto j_tree = Json::Load({str.c_str(), str.size()});
ASSERT_EQ(get<Array>(j_tree["children"]).size(), 2);
}

TEST(Tree, DumpText) {
Expand Down
34 changes: 33 additions & 1 deletion tests/python/test_basic_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ def test_model_json_io(self):
assert locale.getpreferredencoding(False) == loc

@pytest.mark.skipif(**tm.no_json_schema())
def test_json_schema(self):
def test_json_io_schema(self):
import jsonschema
model_path = 'test_json_schema.json'
path = os.path.dirname(
Expand All @@ -342,3 +342,35 @@ def test_json_schema(self):
jsonschema.validate(instance=json_model(model_path, parameters),
schema=schema)
os.remove(model_path)

@pytest.mark.skipif(**tm.no_json_schema())
def test_json_dump_schema(self):
import jsonschema

def validate_model(parameters):
X = np.random.random((100, 30))
y = np.random.randint(0, 4, size=(100,))

parameters['num_class'] = 4
m = xgb.DMatrix(X, y)

booster = xgb.train(parameters, m)
dump = booster.get_dump(dump_format='json')

for i in range(len(dump)):
jsonschema.validate(instance=json.loads(dump[i]),
schema=schema)

path = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
doc = os.path.join(path, 'doc', 'dump.schema')
with open(doc, 'r') as fd:
schema = json.load(fd)

parameters = {'tree_method': 'hist', 'booster': 'gbtree',
'objective': 'multi:softmax'}
validate_model(parameters)

parameters = {'tree_method': 'hist', 'booster': 'dart',
'objective': 'multi:softmax'}
validate_model(parameters)