Skip to content

Commit

Permalink
Add robust bounds check to getters of Tree class (#237)
Browse files Browse the repository at this point in the history
* Add bounds check to getters of Tree class

* LeafVector() and MatchingCategories() should return an empty vector instead of throwing an out-of-bounds exception

* Fix typo
  • Loading branch information
hcho3 authored Dec 29, 2020
1 parent 5b73f72 commit 191a474
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 45 deletions.
85 changes: 54 additions & 31 deletions include/treelite/tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,15 @@ class ContiguousArray {
inline void Clear();
inline void PushBack(T t);
inline void Extend(const std::vector<T>& other);
/* Unsafe access, no bounds checking */
inline T& operator[](size_t idx);
inline const T& operator[](size_t idx) const;
/* Safe access, with bounds checking */
inline T& at(size_t idx);
inline const T& at(size_t idx) const;
/* Safe access, with bounds checking + check against non-existent node (<0) */
inline T& at(int idx);
inline const T& at(int idx) const;
static_assert(std::is_pod<T>::value, "T must be POD");

private:
Expand Down Expand Up @@ -300,14 +307,14 @@ class Tree {
* \param nid ID of node being queried
*/
inline int LeftChild(int nid) const {
return nodes_[nid].cleft_;
return nodes_.at(nid).cleft_;
}
/*!
* \brief index of the node's right child
* \param nid ID of node being queried
*/
inline int RightChild(int nid) const {
return nodes_[nid].cright_;
return nodes_.at(nid).cright_;
}
/*!
* \brief index of the node's "default" child, used when feature is missing
Expand All @@ -321,63 +328,65 @@ class Tree {
* \param nid ID of node being queried
*/
inline uint32_t SplitIndex(int nid) const {
return (nodes_[nid].sindex_ & ((1U << 31U) - 1U));
return (nodes_.at(nid).sindex_ & ((1U << 31U) - 1U));
}
/*!
* \brief whether to use the left child node, when the feature in the split condition is missing
* \param nid ID of node being queried
*/
inline bool DefaultLeft(int nid) const {
return (nodes_[nid].sindex_ >> 31U) != 0;
return (nodes_.at(nid).sindex_ >> 31U) != 0;
}
/*!
* \brief whether the node is leaf node
* \param nid ID of node being queried
*/
inline bool IsLeaf(int nid) const {
return nodes_[nid].cleft_ == -1;
return nodes_.at(nid).cleft_ == -1;
}
/*!
* \brief get leaf value of the leaf node
* \param nid ID of node being queried
*/
inline LeafOutputType LeafValue(int nid) const {
return (nodes_[nid].info_).leaf_value;
return (nodes_.at(nid).info_).leaf_value;
}
/*!
* \brief get leaf vector of the leaf node; useful for multi-class random forest classifier
* \param nid ID of node being queried
*/
inline std::vector<LeafOutputType> LeafVector(int nid) const {
if (static_cast<size_t>(nid) > leaf_vector_offset_.Size()) {
throw std::runtime_error("nid too large");
const size_t offset_begin = leaf_vector_offset_.at(nid);
const size_t offset_end = leaf_vector_offset_.at(nid + 1);
if (offset_begin >= leaf_vector_.Size() || offset_end > leaf_vector_.Size()) {
// Return empty vector, to indicate the lack of leaf vector
return std::vector<LeafOutputType>();
}
return std::vector<LeafOutputType>(&leaf_vector_[leaf_vector_offset_[nid]],
&leaf_vector_[leaf_vector_offset_[nid + 1]]);
return std::vector<LeafOutputType>(&leaf_vector_[offset_begin],
&leaf_vector_[offset_end]);
// Use unsafe access here, since we may need to take the address of one past the last
// element, to match the half-open [begin, end) range semantics of the std::vector<> constructor.
}
/*!
* \brief tests whether the leaf node has a non-empty leaf vector
* \param nid ID of node being queried
*/
inline bool HasLeafVector(int nid) const {
if (static_cast<size_t>(nid) > leaf_vector_offset_.Size()) {
throw std::runtime_error("nid too large");
}
return leaf_vector_offset_[nid] != leaf_vector_offset_[nid + 1];
return leaf_vector_offset_.at(nid) != leaf_vector_offset_.at(nid + 1);
}
/*!
* \brief get threshold of the node
* \param nid ID of node being queried
*/
inline ThresholdType Threshold(int nid) const {
return (nodes_[nid].info_).threshold;
return (nodes_.at(nid).info_).threshold;
}
/*!
* \brief get comparison operator
* \param nid ID of node being queried
*/
inline Operator ComparisonOp(int nid) const {
return nodes_[nid].cmp_;
return nodes_.at(nid).cmp_;
}
/*!
* \brief Get list of all categories belonging to the left/right child node. See the
Expand All @@ -388,69 +397,83 @@ class Tree {
* \param nid ID of node being queried
*/
inline std::vector<uint32_t> MatchingCategories(int nid) const {
if (static_cast<size_t>(nid) > matching_categories_offset_.Size()) {
throw std::runtime_error("nid too large");
const size_t offset_begin = matching_categories_offset_.at(nid);
const size_t offset_end = matching_categories_offset_.at(nid + 1);
if (offset_begin >= matching_categories_.Size() || offset_end > matching_categories_.Size()) {
// Return empty vector, to indicate the lack of any matching categories
// The node might be a numerical split
return std::vector<uint32_t>();
}
return std::vector<uint32_t>(&matching_categories_[matching_categories_offset_[nid]],
&matching_categories_[matching_categories_offset_[nid + 1]]);
return std::vector<uint32_t>(&matching_categories_[offset_begin],
&matching_categories_[offset_end]);
// Use unsafe access here, since we may need to take the address of one past the last
// element, to match the half-open [begin, end) range semantics of the std::vector<> constructor.
}
/*!
 * \brief check whether a node carries at least one matching category. The meaning of
 *        "matching categories" is given in the documentation of MatchingCategories().
 * \param nid ID of node being queried
 * \return true if the offsets recorded for nid and nid + 1 differ, false otherwise
 */
inline bool HasMatchingCategories(int nid) const {
  const auto& offsets = matching_categories_offset_;
  // Both lookups are bounds-checked; the list is non-empty iff the two offsets differ
  const auto range_begin = offsets.at(nid);
  const auto range_end = offsets.at(nid + 1);
  return range_begin != range_end;
}
/*!
* \brief get feature split type
* \param nid ID of node being queried
*/
inline SplitFeatureType SplitType(int nid) const {
return nodes_[nid].split_type_;
return nodes_.at(nid).split_type_;
}
/*!
* \brief test whether this node has data count
* \param nid ID of node being queried
*/
inline bool HasDataCount(int nid) const {
return nodes_[nid].data_count_present_;
return nodes_.at(nid).data_count_present_;
}
/*!
* \brief get data count
* \param nid ID of node being queried
*/
inline uint64_t DataCount(int nid) const {
return nodes_[nid].data_count_;
return nodes_.at(nid).data_count_;
}

/*!
* \brief test whether this node has hessian sum
* \param nid ID of node being queried
*/
inline bool HasSumHess(int nid) const {
return nodes_[nid].sum_hess_present_;
return nodes_.at(nid).sum_hess_present_;
}
/*!
* \brief get hessian sum
* \param nid ID of node being queried
*/
inline double SumHess(int nid) const {
return nodes_[nid].sum_hess_;
return nodes_.at(nid).sum_hess_;
}
/*!
* \brief test whether this node has gain value
* \param nid ID of node being queried
*/
inline bool HasGain(int nid) const {
return nodes_[nid].gain_present_;
return nodes_.at(nid).gain_present_;
}
/*!
* \brief get gain value
* \param nid ID of node being queried
*/
inline double Gain(int nid) const {
return nodes_[nid].gain_;
return nodes_.at(nid).gain_;
}
/*!
* \brief test whether the list given by MatchingCategories(nid) is associated with the right
* child node or the left child node
* \param nid ID of node being queried
*/
inline bool CategoriesListRightChild(int nid) const {
return nodes_[nid].categories_list_right_child_;
return nodes_.at(nid).categories_list_right_child_;
}

/** Setters **/
Expand Down Expand Up @@ -498,7 +521,7 @@ class Tree {
* \param sum_hess hessian sum
*/
inline void SetSumHess(int nid, double sum_hess) {
Node& node = nodes_[nid];
Node& node = nodes_.at(nid);
node.sum_hess_ = sum_hess;
node.sum_hess_present_ = true;
}
Expand All @@ -508,7 +531,7 @@ class Tree {
* \param data_count data count
*/
inline void SetDataCount(int nid, uint64_t data_count) {
Node& node = nodes_[nid];
Node& node = nodes_.at(nid);
node.data_count_ = data_count;
node.data_count_present_ = true;
}
Expand All @@ -518,7 +541,7 @@ class Tree {
* \param gain gain value
*/
inline void SetGain(int nid, double gain) {
Node& node = nodes_[nid];
Node& node = nodes_.at(nid);
node.gain_ = gain;
node.gain_present_ = true;
}
Expand Down
60 changes: 48 additions & 12 deletions include/treelite/tree_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,42 @@ ContiguousArray<T>::operator[](size_t idx) const {
return buffer_[idx];
}

template <typename T>
inline T&
ContiguousArray<T>::at(size_t idx) {
  // Bounds-checked element access: any index at or past Size() is rejected
  const bool within_bounds = (idx < Size());
  if (!within_bounds) {
    throw std::runtime_error("nid out of range");
  }
  return buffer_[idx];
}

template <typename T>
inline const T&
ContiguousArray<T>::at(size_t idx) const {
  // Read-only bounds-checked access: hand out the element only for a valid index
  if (idx < Size()) {
    return buffer_[idx];
  }
  throw std::runtime_error("nid out of range");
}

template <typename T>
inline T&
ContiguousArray<T>::at(int idx) {
  // A negative index refers to a non-existent node; reject it before converting to size_t
  if (idx < 0) {
    throw std::runtime_error("nid out of range");
  }
  const size_t pos = static_cast<size_t>(idx);
  if (pos >= Size()) {
    throw std::runtime_error("nid out of range");
  }
  return buffer_[pos];
}

template <typename T>
inline const T&
ContiguousArray<T>::at(int idx) const {
  // Only a non-negative index strictly below Size() yields an element;
  // everything else (including negative, i.e. non-existent, node IDs) throws.
  if (idx >= 0) {
    const size_t pos = static_cast<size_t>(idx);
    if (pos < Size()) {
      return buffer_[pos];
    }
  }
  throw std::runtime_error("nid out of range");
}

template<typename Container>
inline std::vector<std::pair<std::string, std::string> >
ModelParam::InitAllowUnknown(const Container& kwargs) {
Expand Down Expand Up @@ -489,7 +525,7 @@ Tree<ThresholdType, LeafOutputType>::Init() {
matching_categories_.Clear();
matching_categories_offset_.Resize(2, 0);
nodes_.Resize(1);
nodes_[0].Init();
nodes_.at(0).Init();
SetLeaf(0, static_cast<LeafOutputType>(0));
}

Expand All @@ -498,8 +534,8 @@ inline void
Tree<ThresholdType, LeafOutputType>::AddChilds(int nid) {
const int cleft = this->AllocNode();
const int cright = this->AllocNode();
nodes_[nid].cleft_ = cleft;
nodes_[nid].cright_ = cright;
nodes_.at(nid).cleft_ = cleft;
nodes_.at(nid).cright_ = cright;
}

template <typename ThresholdType, typename LeafOutputType>
Expand Down Expand Up @@ -535,7 +571,7 @@ template <typename ThresholdType, typename LeafOutputType>
inline void
Tree<ThresholdType, LeafOutputType>::SetNumericalSplit(
int nid, unsigned split_index, ThresholdType threshold, bool default_left, Operator cmp) {
Node& node = nodes_[nid];
Node& node = nodes_.at(nid);
if (split_index >= ((1U << 31U) - 1)) {
throw std::runtime_error("split_index too big");
}
Expand All @@ -561,7 +597,7 @@ Tree<ThresholdType, LeafOutputType>::SetCategoricalSplit(
if (end_oft != matching_categories_.Size()) {
throw std::runtime_error("Invariant violated");
}
if (!std::all_of(&matching_categories_offset_[nid + 1], matching_categories_offset_.End(),
if (!std::all_of(&matching_categories_offset_.at(nid + 1), matching_categories_offset_.End(),
[end_oft](size_t x) { return (x == end_oft); })) {
throw std::runtime_error("Invariant violated");
}
Expand All @@ -570,11 +606,11 @@ Tree<ThresholdType, LeafOutputType>::SetCategoricalSplit(
if (new_end_oft != matching_categories_.Size()) {
throw std::runtime_error("Invariant violated");
}
std::for_each(&matching_categories_offset_[nid + 1], matching_categories_offset_.End(),
std::for_each(&matching_categories_offset_.at(nid + 1), matching_categories_offset_.End(),
[new_end_oft](size_t& x) { x = new_end_oft; });
std::sort(&matching_categories_[end_oft], matching_categories_.End());
std::sort(&matching_categories_.at(end_oft), matching_categories_.End());

Node& node = nodes_[nid];
Node& node = nodes_.at(nid);
if (default_left) split_index |= (1U << 31U);
node.sindex_ = split_index;
node.split_type_ = SplitFeatureType::kCategorical;
Expand All @@ -584,7 +620,7 @@ Tree<ThresholdType, LeafOutputType>::SetCategoricalSplit(
template <typename ThresholdType, typename LeafOutputType>
inline void
Tree<ThresholdType, LeafOutputType>::SetLeaf(int nid, LeafOutputType value) {
Node& node = nodes_[nid];
Node& node = nodes_.at(nid);
(node.info_).leaf_value = value;
node.cleft_ = -1;
node.cright_ = -1;
Expand All @@ -600,7 +636,7 @@ Tree<ThresholdType, LeafOutputType>::SetLeafVector(
if (end_oft != leaf_vector_.Size()) {
throw std::runtime_error("Invariant violated");
}
if (!std::all_of(&leaf_vector_offset_[nid + 1], leaf_vector_offset_.End(),
if (!std::all_of(&leaf_vector_offset_.at(nid + 1), leaf_vector_offset_.End(),
[end_oft](size_t x) { return (x == end_oft); })) {
throw std::runtime_error("Invariant violated");
}
Expand All @@ -609,10 +645,10 @@ Tree<ThresholdType, LeafOutputType>::SetLeafVector(
if (new_end_oft != leaf_vector_.Size()) {
throw std::runtime_error("Invariant violated");
}
std::for_each(&leaf_vector_offset_[nid + 1], leaf_vector_offset_.End(),
std::for_each(&leaf_vector_offset_.at(nid + 1), leaf_vector_offset_.End(),
[new_end_oft](size_t& x) { x = new_end_oft; });

Node& node = nodes_[nid];
Node& node = nodes_.at(nid);
node.cleft_ = -1;
node.cright_ = -1;
node.split_type_ = SplitFeatureType::kNone;
Expand Down
4 changes: 2 additions & 2 deletions src/compiler/failsafe.cc
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ inline std::pair<std::string, std::string> FormatNodesArray(
"cright"_a = -1);
} else {
CHECK(tree.SplitType(nid) == treelite::SplitFeatureType::kNumerical
&& tree.MatchingCategories(nid).empty())
&& !tree.HasMatchingCategories(nid))
<< "categorical splits are not supported in FailSafeCompiler";
nodes << fmt::format("{{ 0x{sindex:X}, {info}, {cleft}, {cright} }}",
"sindex"_a
Expand Down Expand Up @@ -185,7 +185,7 @@ inline std::pair<std::vector<char>, std::string> FormatNodesArrayELF(
val = {0, static_cast<float>(tree.LeafValue(nid)), -1, -1};
} else {
CHECK(tree.SplitType(nid) == treelite::SplitFeatureType::kNumerical
&& tree.MatchingCategories(nid).empty())
&& !tree.HasMatchingCategories(nid))
<< "categorical splits are not supported in FailSafeCompiler";
val = {(tree.SplitIndex(nid) | (static_cast<uint32_t>(tree.DefaultLeft(nid)) << 31)),
static_cast<float>(tree.Threshold(nid)), tree.LeftChild(nid), tree.RightChild(nid)};
Expand Down

0 comments on commit 191a474

Please sign in to comment.