Skip to content

Commit

Permalink
Simplify model spec by removing missing_value_to_zero field (#234)
Browse files (browse the repository at this point in the history)
* Simplify model spec by removing missing_value_to_zero field

* Fix formatting check

* Update format string in PyBuffer

* Fix an affected test

* Apply suggestions from code review

Co-authored-by: Andy Adinets <adinetz@gmail.com>

Co-authored-by: Andy Adinets <adinetz@gmail.com>
  • Loading branch information
hcho3 and canonizer authored Dec 11, 2020
1 parent 737059a commit 9fb79c3
Show file tree
Hide file tree
Showing 11 changed files with 66 additions and 122 deletions.
16 changes: 1 addition & 15 deletions include/treelite/tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -226,10 +226,6 @@ class Tree {
* otherwise, take the right child.
*/
Operator cmp_;
/* \brief Whether to convert missing value to zero.
* When this flag is set, it overrides the behavior of default_left().
*/
bool missing_value_to_zero_;
/*! \brief whether data_count_ field is present */
bool data_count_present_;
/*! \brief whether sum_hess_ field is present */
Expand Down Expand Up @@ -446,13 +442,6 @@ class Tree {
inline double Gain(int nid) const {
return nodes_[nid].gain_;
}
/*!
* \brief test whether missing values should be converted into zero
* \param nid ID of node being queried
*/
inline bool MissingValueToZero(int nid) const {
return nodes_[nid].missing_value_to_zero_;
}
/*!
* \brief test whether the list given by MatchingCategories(nid) is associated with the right
* child node or the left child node
Expand All @@ -469,18 +458,16 @@ class Tree {
* \param split_index feature index to split
* \param threshold threshold value
* \param default_left the default direction when feature is unknown
* \param missing_value_to_zero whether missing values should be converted into zero
* \param cmp comparison operator to compare between feature value and
* threshold
*/
inline void SetNumericalSplit(int nid, unsigned split_index, ThresholdType threshold,
bool default_left, bool missing_value_to_zero, Operator cmp);
bool default_left, Operator cmp);
/*!
* \brief create a categorical split
* \param nid ID of node being updated
* \param split_index feature index to split
* \param default_left the default direction when feature is unknown
* \param missing_value_to_zero whether missing values should be converted into zero
* \param categories_list list of categories to belong to either the right child node or the left
* child node. Set categories_list_right_child parameter to indicate
* which node the category list should represent.
Expand All @@ -489,7 +476,6 @@ class Tree {
* (false)
*/
inline void SetCategoricalSplit(int nid, unsigned split_index, bool default_left,
bool missing_value_to_zero,
const std::vector<uint32_t>& categories_list,
bool categories_list_right_child);
/*!
Expand Down
12 changes: 4 additions & 8 deletions include/treelite/tree_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -400,9 +400,9 @@ template <typename ThresholdType, typename LeafOutputType>
inline const char*
Tree<ThresholdType, LeafOutputType>::GetFormatStringForNode() {
if (std::is_same<ThresholdType, float>::value) {
return "T{=l=l=L=f=Q=d=d=b=b=?=?=?=?=?x}";
return "T{=l=l=L=f=Q=d=d=b=b=?=?=?=?xx}";
} else {
return "T{=l=l=Lxxxx=d=Q=d=d=b=b=?=?=?=?=?x}";
return "T{=l=l=Lxxxx=d=Q=d=d=b=b=?=?=?=?xx}";
}
}

Expand Down Expand Up @@ -445,7 +445,6 @@ inline void Tree<ThresholdType, LeafOutputType>::Node::Init() {
info_.threshold = static_cast<ThresholdType>(0);
data_count_ = 0;
sum_hess_ = gain_ = 0.0;
missing_value_to_zero_ = false;
data_count_present_ = sum_hess_present_ = gain_present_ = false;
categories_list_right_child_ = false;
split_type_ = SplitFeatureType::kNone;
Expand Down Expand Up @@ -522,8 +521,7 @@ Tree<ThresholdType, LeafOutputType>::GetCategoricalFeatures() const {
template <typename ThresholdType, typename LeafOutputType>
inline void
Tree<ThresholdType, LeafOutputType>::SetNumericalSplit(
int nid, unsigned split_index, ThresholdType threshold, bool default_left,
bool missing_value_to_zero, Operator cmp) {
int nid, unsigned split_index, ThresholdType threshold, bool default_left, Operator cmp) {
Node& node = nodes_[nid];
if (split_index >= ((1U << 31U) - 1)) {
throw std::runtime_error("split_index too big");
Expand All @@ -533,14 +531,13 @@ Tree<ThresholdType, LeafOutputType>::SetNumericalSplit(
(node.info_).threshold = threshold;
node.cmp_ = cmp;
node.split_type_ = SplitFeatureType::kNumerical;
node.missing_value_to_zero_ = missing_value_to_zero;
node.categories_list_right_child_ = false;
}

template <typename ThresholdType, typename LeafOutputType>
inline void
Tree<ThresholdType, LeafOutputType>::SetCategoricalSplit(
int nid, unsigned split_index, bool default_left, bool missing_value_to_zero,
int nid, unsigned split_index, bool default_left,
const std::vector<uint32_t>& categories_list, bool categories_list_right_child) {
if (split_index >= ((1U << 31U) - 1)) {
throw std::runtime_error("split_index too big");
Expand Down Expand Up @@ -568,7 +565,6 @@ Tree<ThresholdType, LeafOutputType>::SetCategoricalSplit(
if (default_left) split_index |= (1U << 31U);
node.sindex_ = split_index;
node.split_type_ = SplitFeatureType::kCategorical;
node.missing_value_to_zero_ = missing_value_to_zero;
node.categories_list_right_child_ = categories_list_right_child;
}

Expand Down
22 changes: 8 additions & 14 deletions src/compiler/ast/ast.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,23 +102,19 @@ class CodeFolderNode : public ASTNode {

class ConditionNode : public ASTNode {
public:
ConditionNode(unsigned split_index, bool default_left, bool convert_missing_to_zero)
: split_index(split_index), default_left(default_left),
convert_missing_to_zero(convert_missing_to_zero) {}
ConditionNode(unsigned split_index, bool default_left)
: split_index(split_index), default_left(default_left) {}
unsigned split_index;
bool default_left;
bool convert_missing_to_zero;
dmlc::optional<double> gain;

std::string GetDump() const override {
if (gain) {
return fmt::format("ConditionNode {{ split_index: {}, default_left: {}, "
"convert_missing_to_zero: {}, gain: {} }}",
split_index, default_left, convert_missing_to_zero, gain.value());
return fmt::format("ConditionNode {{ split_index: {}, default_left: {}, gain: {} }}",
split_index, default_left, gain.value());
} else {
return fmt::format("ConditionNode {{ split_index: {}, default_left: {}, "
"convert_missing_to_zero: {} }}",
split_index, default_left, convert_missing_to_zero);
return fmt::format("ConditionNode {{ split_index: {}, default_left: {} }}",
split_index, default_left);
}
}
};
Expand All @@ -135,10 +131,9 @@ template <typename ThresholdType>
class NumericalConditionNode : public ConditionNode {
public:
NumericalConditionNode(unsigned split_index, bool default_left,
bool convert_missing_to_zero,
bool quantized, Operator op,
ThresholdVariant<ThresholdType> threshold)
: ConditionNode(split_index, default_left, convert_missing_to_zero),
: ConditionNode(split_index, default_left),
quantized(quantized), op(op), threshold(threshold), zero_quantized(-1) {}
bool quantized;
Operator op;
Expand All @@ -158,10 +153,9 @@ class NumericalConditionNode : public ConditionNode {
class CategoricalConditionNode : public ConditionNode {
public:
CategoricalConditionNode(unsigned split_index, bool default_left,
bool convert_missing_to_zero,
const std::vector<uint32_t>& matching_categories,
bool categories_list_right_child)
: ConditionNode(split_index, default_left, convert_missing_to_zero),
: ConditionNode(split_index, default_left),
matching_categories(matching_categories),
categories_list_right_child(categories_list_right_child) {}
std::vector<uint32_t> matching_categories;
Expand Down
2 changes: 0 additions & 2 deletions src/compiler/ast/build.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ ASTBuilder<ThresholdType, LeafOutputType>::BuildASTFromTree(
parent,
tree.SplitIndex(nid),
tree.DefaultLeft(nid),
tree.MissingValueToZero(nid),
false,
tree.ComparisonOp(nid),
ThresholdVariant<ThresholdType>(tree.Threshold(nid)));
Expand All @@ -58,7 +57,6 @@ ASTBuilder<ThresholdType, LeafOutputType>::BuildASTFromTree(
parent,
tree.SplitIndex(nid),
tree.DefaultLeft(nid),
tree.MissingValueToZero(nid),
tree.MatchingCategories(nid),
tree.CategoriesListRightChild(nid));
}
Expand Down
76 changes: 24 additions & 52 deletions src/compiler/ast_native.cc
Original file line number Diff line number Diff line change
Expand Up @@ -315,24 +315,14 @@ class ASTNativeCompiler : public Compiler {
if ( (t = dynamic_cast<const NumericalConditionNode<ThresholdType>*>(node)) ) {
/* numerical split */
std::string condition = ExtractNumericalCondition(t);
if (node->convert_missing_to_zero) {
std::string condition_for_na = ExtractNumericalCondition(t, /* use_zero_threshold */ true);
condition_with_na_check
= fmt::format("(!(data[{split_index}].missing != -1) && {condition_for_na}) ||"
"( (data[{split_index}].missing != -1) && {condition})",
"split_index"_a = node->split_index,
"condition_for_na"_a = condition_for_na,
"condition"_a = condition);
} else {
const char* condition_with_na_check_template
= (node->default_left) ?
"!(data[{split_index}].missing != -1) || ({condition})"
: " (data[{split_index}].missing != -1) && ({condition})";
condition_with_na_check
= fmt::format(condition_with_na_check_template,
"split_index"_a = node->split_index,
"condition"_a = condition);
}
const char* condition_with_na_check_template
= (node->default_left) ?
"!(data[{split_index}].missing != -1) || ({condition})"
: " (data[{split_index}].missing != -1) && ({condition})";
condition_with_na_check
= fmt::format(condition_with_na_check_template,
"split_index"_a = node->split_index,
"condition"_a = condition);
} else { /* categorical split */
const CategoricalConditionNode* t2 = dynamic_cast<const CategoricalConditionNode*>(node);
CHECK(t2);
Expand Down Expand Up @@ -586,19 +576,13 @@ class ASTNativeCompiler : public Compiler {

template <typename ThresholdType>
inline std::string
ExtractNumericalCondition(const NumericalConditionNode<ThresholdType>* node,
bool use_zero_threshold = false) {
ExtractNumericalCondition(const NumericalConditionNode<ThresholdType>* node) {
const std::string threshold_type
= native::TypeInfoToCTypeString(TypeToInfo<ThresholdType>());
std::string result;
if (node->quantized) { // quantized threshold
std::string lhs;
if (use_zero_threshold) {
lhs = fmt::format("{}", node->zero_quantized);
} else {
lhs = fmt::format("data[{split_index}].qvalue",
"split_index"_a = node->split_index);
}
std::string lhs = fmt::format("data[{split_index}].qvalue",
"split_index"_a = node->split_index);
result = fmt::format("{lhs} {opname} {threshold}",
"lhs"_a = lhs,
"opname"_a = OpName(node->op),
Expand All @@ -609,13 +593,8 @@ class ASTNativeCompiler : public Compiler {
result = (CompareWithOp(static_cast<ThresholdType>(0), node->op, node->threshold.float_val)
? "1" : "0");
} else { // finite threshold
std::string lhs;
if (use_zero_threshold) {
lhs = fmt::format("{}", common_util::ToStringHighPrecision(static_cast<ThresholdType>(0)));
} else {
lhs = fmt::format("data[{split_index}].fvalue",
"split_index"_a = node->split_index);
}
std::string lhs = fmt::format("data[{split_index}].fvalue",
"split_index"_a = node->split_index);
result
= fmt::format("{lhs} {opname} ({threshold_type}){threshold}",
"lhs"_a = lhs,
Expand All @@ -640,26 +619,19 @@ class ASTNativeCompiler : public Compiler {
result = "0";
} else {
std::ostringstream oss;
if (node->convert_missing_to_zero) {
// All missing values are converted into zeros
const std::string right_categories_flag = (node->categories_list_right_child ? "!" : "");
if (node->default_left) {
oss << fmt::format(
"((tmp = (data[{0}].missing == -1 ? 0U "
": (unsigned int)(data[{0}].fvalue) )), ", node->split_index);
"data[{split_index}].missing == -1 || {right_categories_flag}("
"(tmp = (unsigned int)(data[{split_index}].fvalue) ), ",
"split_index"_a = node->split_index,
"right_categories_flag"_a = right_categories_flag);
} else {
const std::string right_categories_flag = (node->categories_list_right_child ? "!" : "");
if (node->default_left) {
oss << fmt::format(
"data[{split_index}].missing == -1 || {right_categories_flag}("
"(tmp = (unsigned int)(data[{split_index}].fvalue) ), ",
"split_index"_a = node->split_index,
"right_categories_flag"_a = right_categories_flag);
} else {
oss << fmt::format(
"data[{split_index}].missing != -1 && {right_categories_flag}("
"(tmp = (unsigned int)(data[{split_index}].fvalue) ), ",
"split_index"_a = node->split_index,
"right_categories_flag"_a = right_categories_flag);
}
oss << fmt::format(
"data[{split_index}].missing != -1 && {right_categories_flag}("
"(tmp = (unsigned int)(data[{split_index}].fvalue) ), ",
"split_index"_a = node->split_index,
"right_categories_flag"_a = right_categories_flag);
}
oss << "(tmp >= 0 && tmp < 64 && (( (uint64_t)"
<< bitmap[0] << "U >> tmp) & 1) )";
Expand Down
7 changes: 1 addition & 6 deletions src/compiler/common/code_folding_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,7 @@ RenderCodeFolderArrays(const CodeFolderNode* node,
default_left = t3->default_left;
split_index = t3->split_index;
threshold = "-1"; // dummy value
CHECK(!t3->convert_missing_to_zero)
<< "Code folding not supported, because a categorical split "
<< "is supposed to convert missing values into zeros, and this "
<< "is not possible with current code folding implementation.";
std::vector<uint64_t> bitmap
= GetCategoricalBitmap(t3->matching_categories);
std::vector<uint64_t> bitmap = GetCategoricalBitmap(t3->matching_categories);
cat_bitmap.insert(cat_bitmap.end(), bitmap.begin(), bitmap.end());
cat_begin.push_back(cat_bitmap.size());
}
Expand Down
7 changes: 3 additions & 4 deletions src/frontend/builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -472,8 +472,7 @@ ModelBuilderImpl::CommitModelImpl(ModelImpl<ThresholdType, LeafOutputType>* out_
<< TypeInfoToString(TypeToInfo<ThresholdType>())
<< " Given: " << TypeInfoToString(node->threshold.GetValueType());
ThresholdType threshold = node->threshold.Get<ThresholdType>();
tree.SetNumericalSplit(nid, node->feature_id, threshold, node->default_left, false,
node->op);
tree.SetNumericalSplit(nid, node->feature_id, threshold, node->default_left, node->op);
Q.push({node->left_child, tree.LeftChild(nid)});
Q.push({node->right_child, tree.RightChild(nid)});
} else if (node->status == NodeDraft::Status::kCategoricalTest) {
Expand All @@ -482,8 +481,8 @@ ModelBuilderImpl::CommitModelImpl(ModelImpl<ThresholdType, LeafOutputType>* out_
CHECK(node->left_child->parent == node) << "CommitModel: left child has wrong parent";
CHECK(node->right_child->parent == node) << "CommitModel: right child has wrong parent";
tree.AddChilds(nid);
tree.SetCategoricalSplit(nid, node->feature_id, node->default_left, false,
node->left_categories, false);
tree.SetCategoricalSplit(nid, node->feature_id, node->default_left, node->left_categories,
false);
Q.push({node->left_child, tree.LeftChild(nid)});
Q.push({node->right_child, tree.RightChild(nid)});
} else { // leaf node
Expand Down
23 changes: 18 additions & 5 deletions src/frontend/lightgbm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -581,16 +581,29 @@ inline std::unique_ptr<treelite::Model> ParseStream(dmlc::Stream* fi) {
+ lgb_tree.cat_boundaries[cat_idx],
lgb_tree.cat_boundaries[cat_idx + 1]
- lgb_tree.cat_boundaries[cat_idx]);
tree.SetCategoricalSplit(new_id, split_index, false, (missing_type != MissingType::kNaN),
left_categories, false);
const bool missing_value_to_zero = missing_type != MissingType::kNaN;
bool default_left = false;
if (missing_value_to_zero) {
// If missing_value_to_zero flag is true, all missing values get mapped to 0.0, so
// we need to override the default_left flag
default_left
= (std::find(left_categories.begin(), left_categories.end(),
static_cast<uint32_t>(0)) != left_categories.end());
}
tree.SetCategoricalSplit(new_id, split_index, default_left, left_categories, false);
} else {
// numerical
const auto threshold = static_cast<double>(lgb_tree.threshold[old_id]);
const bool default_left
bool default_left
= GetDecisionType(lgb_tree.decision_type[old_id], kDefaultLeftMask);
const treelite::Operator cmp_op = treelite::Operator::kLE;
tree.SetNumericalSplit(new_id, split_index, threshold, default_left,
(missing_type != MissingType::kNaN), cmp_op);
const bool missing_value_to_zero = (missing_type != MissingType::kNaN);
if (missing_value_to_zero) {
// If missing_value_to_zero flag is true, all missing values get mapped to 0.0, so
// we need to override the default_left flag
default_left = 0.0 <= threshold;
}
tree.SetNumericalSplit(new_id, split_index, threshold, default_left, cmp_op);
}
if (!lgb_tree.internal_count.empty()) {
const int data_count = lgb_tree.internal_count[old_id];
Expand Down
2 changes: 1 addition & 1 deletion src/frontend/xgboost.cc
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ inline std::unique_ptr<treelite::Model> ParseStream(dmlc::Stream* fi) {
const bst_float split_cond = node.split_cond();
tree.AddChilds(new_id);
tree.SetNumericalSplit(new_id, node.split_index(),
static_cast<float>(split_cond), node.default_left(), false, treelite::Operator::kLT);
static_cast<float>(split_cond), node.default_left(), treelite::Operator::kLT);
tree.SetGain(new_id, stat.loss_chg);
Q.push({node.cleft(), tree.LeftChild(new_id)});
Q.push({node.cright(), tree.RightChild(new_id)});
Expand Down
4 changes: 2 additions & 2 deletions src/frontend/xgboost_json.cc
Original file line number Diff line number Diff line change
Expand Up @@ -295,11 +295,11 @@ bool RegTreeHandler::EndObject(std::size_t memberCount) {
right_categories.push_back(static_cast<uint32_t>(categories[offset + i]));
}
output.SetCategoricalSplit(
new_id, split_indices[old_id], default_left[old_id], false, right_categories, true);
new_id, split_indices[old_id], default_left[old_id], right_categories, true);
} else {
output.SetNumericalSplit(
new_id, split_indices[old_id], split_conditions[old_id],
default_left[old_id], false, treelite::Operator::kLT);
default_left[old_id], treelite::Operator::kLT);
}
output.SetGain(new_id, loss_changes[old_id]);
Q.push({left_children[old_id], output.LeftChild(new_id)});
Expand Down
Loading

0 comments on commit 9fb79c3

Please sign in to comment.