Skip to content

Commit

Permalink
Accommodate infinite thresholds
Browse files Browse the repository at this point in the history
Fixing issue #8
  • Loading branch information
hcho3 committed Oct 27, 2017
1 parent 0988e42 commit da8f3e9
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 22 deletions.
19 changes: 19 additions & 0 deletions include/treelite/semantic.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,25 @@ inline std::string OpName(Operator op) {
}
}

/*!
* \brief perform comparison between two float's using a comparsion operator
* The comparison will be in the form [lhs] [op] [rhs].
* \param lhs float on the left hand side
* \param op comparison operator
* \param rhs float on the right hand side
* \return whether [lhs] [op] [rhs] is true or not
*/
inline bool CompareWithOp(tl_float lhs, Operator op, tl_float rhs) {
switch(op) {
case Operator::kEQ: return lhs == rhs;
case Operator::kLT: return lhs < rhs;
case Operator::kLE: return lhs <= rhs;
case Operator::kGT: return lhs > rhs;
case Operator::kGE: return lhs >= rhs;
default: LOG(FATAL) << "operator undefined";
}
}

using common::Cloneable;
using common::DeepCopyUniquePtr;

Expand Down
16 changes: 2 additions & 14 deletions src/annotator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
*/

#include <treelite/annotator.h>
#include <treelite/semantic.h>
#include <treelite/omp.h>
#include <cstdint>
#include <limits>
Expand Down Expand Up @@ -34,20 +35,7 @@ void Traverse_(const treelite::Tree& tree, const Entry* data,
const treelite::Operator op = node.comparison_op();
const treelite::tl_float fvalue
= static_cast<treelite::tl_float>(data[split_index].fvalue);
switch (op) {
case treelite::Operator::kEQ:
result = (fvalue == threshold); break;
case treelite::Operator::kLT:
result = (fvalue < threshold); break;
case treelite::Operator::kLE:
result = (fvalue <= threshold); break;
case treelite::Operator::kGT:
result = (fvalue > threshold); break;
case treelite::Operator::kGE:
result = (fvalue >= threshold); break;
default:
LOG(FATAL) << "operator undefined";
}
result = treelite::semantic::CompareWithOp(fvalue, op, threshold);
} else {
const auto fvalue = data[split_index].fvalue;
CHECK_LT(fvalue, 64) << "Cannot have more than 64 categories";
Expand Down
31 changes: 23 additions & 8 deletions src/compiler/recursive.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <queue>
#include <algorithm>
#include <iterator>
#include <cmath>
#include "param.h"
#include "pred_transform.h"

Expand Down Expand Up @@ -336,8 +337,14 @@ class NoQuantize : private MetadataStore {
NumericSplitCondition::NumericAdapter NumericAdapter() const {
return [] (Operator op, unsigned split_index, tl_float threshold) {
std::ostringstream oss;
oss << "data[" << split_index << "].fvalue "
<< semantic::OpName(op) << " " << threshold;
if (!std::isfinite(threshold)) {
// According to IEEE 754, the result of comparison [lhs] < infinity
// must be identical for all finite [lhs]. Same goes for operator >.
oss << (semantic::CompareWithOp(0.0, op, threshold) ? "1" : "0");
} else {
oss << "data[" << split_index << "].fvalue "
<< semantic::OpName(op) << " " << threshold;
}
return oss.str();
};
}
Expand Down Expand Up @@ -382,11 +389,17 @@ class Quantize : private MetadataStore {
tl_float threshold) {
std::ostringstream oss;
const auto& v = cut_pts[split_index];
auto loc = common::binary_search(v.begin(), v.end(), threshold);
CHECK(loc != v.end());
oss << "data[" << split_index << "].qvalue " << semantic::OpName(op)
<< " " << static_cast<size_t>(loc - v.begin()) * 2;
return oss.str();
if (!std::isfinite(threshold)) {
// According to IEEE 754, the result of comparison [lhs] < infinity
// must be identical for all finite [lhs]. Same goes for operator >.
oss << (semantic::CompareWithOp(0.0, op, threshold) ? "1" : "0");
} else {
auto loc = common::binary_search(v.begin(), v.end(), threshold);
CHECK(loc != v.end());
oss << "data[" << split_index << "].qvalue " << semantic::OpName(op)
<< " " << static_cast<size_t>(loc - v.begin()) * 2;
return oss.str();
}
};
}
std::vector<std::string> CommonHeader() const {
Expand Down Expand Up @@ -527,7 +540,9 @@ ExtractCutPoints(const Model& model) {
if (node.split_type() == SplitFeatureType::kNumerical) {
const tl_float threshold = node.threshold();
const unsigned split_index = node.split_index();
thresh_[split_index].insert(threshold);
if (std::isfinite(threshold)) { // ignore infinity
thresh_[split_index].insert(threshold);
}
} else {
CHECK(node.split_type() == SplitFeatureType::kCategorical);
}
Expand Down

0 comments on commit da8f3e9

Please sign in to comment.