From da8f3e9f88cc7b1ec08d0126e155d3b258b566d7 Mon Sep 17 00:00:00 2001 From: Philip Cho Date: Fri, 27 Oct 2017 14:02:24 -0700 Subject: [PATCH] Accommodate infinite thresholds Fixing issue #8 --- include/treelite/semantic.h | 19 +++++++++++++++++++ src/annotator.cc | 16 ++-------------- src/compiler/recursive.cc | 31 +++++++++++++++++++++++-------- 3 files changed, 44 insertions(+), 22 deletions(-) diff --git a/include/treelite/semantic.h b/include/treelite/semantic.h index ffe50175..af893035 100644 --- a/include/treelite/semantic.h +++ b/include/treelite/semantic.h @@ -36,6 +36,25 @@ inline std::string OpName(Operator op) { } } +/*! + * \brief perform comparison between two float's using a comparsion operator + * The comparison will be in the form [lhs] [op] [rhs]. + * \param lhs float on the left hand side + * \param op comparison operator + * \param rhs float on the right hand side + * \return whether [lhs] [op] [rhs] is true or not + */ +inline bool CompareWithOp(tl_float lhs, Operator op, tl_float rhs) { + switch(op) { + case Operator::kEQ: return lhs == rhs; + case Operator::kLT: return lhs < rhs; + case Operator::kLE: return lhs <= rhs; + case Operator::kGT: return lhs > rhs; + case Operator::kGE: return lhs >= rhs; + default: LOG(FATAL) << "operator undefined"; + } +} + using common::Cloneable; using common::DeepCopyUniquePtr; diff --git a/src/annotator.cc b/src/annotator.cc index 31370136..10ed5b9a 100644 --- a/src/annotator.cc +++ b/src/annotator.cc @@ -6,6 +6,7 @@ */ #include +#include #include #include #include @@ -34,20 +35,7 @@ void Traverse_(const treelite::Tree& tree, const Entry* data, const treelite::Operator op = node.comparison_op(); const treelite::tl_float fvalue = static_cast(data[split_index].fvalue); - switch (op) { - case treelite::Operator::kEQ: - result = (fvalue == threshold); break; - case treelite::Operator::kLT: - result = (fvalue < threshold); break; - case treelite::Operator::kLE: - result = (fvalue <= threshold); break; - case treelite::Operator::kGT: - result = (fvalue > threshold); break; - case treelite::Operator::kGE: - result = (fvalue >= threshold); break; - default: - LOG(FATAL) << "operator undefined"; - } + result = treelite::semantic::CompareWithOp(fvalue, op, threshold); } else { const auto fvalue = data[split_index].fvalue; CHECK_LT(fvalue, 64) << "Cannot have more than 64 categories"; diff --git a/src/compiler/recursive.cc b/src/compiler/recursive.cc index a09a107a..71206d45 100644 --- a/src/compiler/recursive.cc +++ b/src/compiler/recursive.cc @@ -14,6 +14,7 @@ #include #include #include +#include #include "param.h" #include "pred_transform.h" @@ -336,8 +337,14 @@ class NoQuantize : private MetadataStore { NumericSplitCondition::NumericAdapter NumericAdapter() const { return [] (Operator op, unsigned split_index, tl_float threshold) { std::ostringstream oss; - oss << "data[" << split_index << "].fvalue " - << semantic::OpName(op) << " " << threshold; + if (!std::isfinite(threshold)) { + // According to IEEE 754, the result of comparison [lhs] < infinity + // must be identical for all finite [lhs]. Same goes for operator >. + oss << (semantic::CompareWithOp(0.0, op, threshold) ? "1" : "0"); + } else { + oss << "data[" << split_index << "].fvalue " + << semantic::OpName(op) << " " << threshold; + } return oss.str(); }; } @@ -382,11 +389,17 @@ class Quantize : private MetadataStore { tl_float threshold) { std::ostringstream oss; const auto& v = cut_pts[split_index]; - auto loc = common::binary_search(v.begin(), v.end(), threshold); - CHECK(loc != v.end()); - oss << "data[" << split_index << "].qvalue " << semantic::OpName(op) - << " " << static_cast(loc - v.begin()) * 2; - return oss.str(); + if (!std::isfinite(threshold)) { + // According to IEEE 754, the result of comparison [lhs] < infinity + // must be identical for all finite [lhs]. Same goes for operator >. + oss << (semantic::CompareWithOp(0.0, op, threshold) ? "1" : "0"); + } else { + auto loc = common::binary_search(v.begin(), v.end(), threshold); + CHECK(loc != v.end()); + oss << "data[" << split_index << "].qvalue " << semantic::OpName(op) + << " " << static_cast(loc - v.begin()) * 2; + return oss.str(); + } }; } std::vector CommonHeader() const { @@ -527,7 +540,9 @@ ExtractCutPoints(const Model& model) { if (node.split_type() == SplitFeatureType::kNumerical) { const tl_float threshold = node.threshold(); const unsigned split_index = node.split_index(); - thresh_[split_index].insert(threshold); + if (std::isfinite(threshold)) { // ignore infinity + thresh_[split_index].insert(threshold); + } } else { CHECK(node.split_type() == SplitFeatureType::kCategorical); }