float64 support in treelite->FIL import and Python layer #4690

Merged
merged 58 commits into from
Apr 13, 2022
Changes from 56 commits
58 commits
86bc6b5
templatized node, forest and storage types
levsnv Feb 9, 2022
abfb605
compiles?
levsnv Feb 11, 2022
a5d970f
Merge branch 'branch-22.04' of github.com:rapidsai/cuml into fp64
levsnv Feb 11, 2022
96e2ae6
simplify leaf_output_t
levsnv Feb 11, 2022
ead832d
Merge branch 'branch-22.04' of github.com:rapidsai/cuml into fp64
levsnv Feb 12, 2022
35ad5d9
draft
levsnv Feb 16, 2022
df001e7
Merge branch 'branch-22.04' of github.com:rapidsai/cuml into fp64
levsnv Feb 16, 2022
f6efb88
fixed extra/missing instantiations
levsnv Feb 16, 2022
4bc4f99
Merge branch 'fp64' into func64
levsnv Feb 16, 2022
1bebcca
style
levsnv Feb 16, 2022
aab158b
style
levsnv Feb 16, 2022
c3f2d4f
Merge branch 'fp64' into func64
levsnv Feb 16, 2022
1236c52
removed ML::fil::init templatization, added KeyValuePair templatization
levsnv Feb 18, 2022
cc946cd
style
levsnv Feb 18, 2022
91ecb20
add the old instantiations back
levsnv Feb 18, 2022
67f900e
fixed
levsnv Feb 18, 2022
c3c4deb
style
levsnv Feb 18, 2022
68e7f65
base_node::output() now compiles.
canonizer Mar 9, 2022
fe2cf92
Merge branch 'branch-22.04' into dev-fil-fp64
canonizer Mar 9, 2022
b31069d
Fixed style.
canonizer Mar 10, 2022
c250653
F -> real_t.
canonizer Mar 10, 2022
2f658b5
Small fixes.
canonizer Mar 10, 2022
6fb105b
Updated alignment.
canonizer Mar 10, 2022
f1a10be
static_assert(real_t == float) in a number of places.
canonizer Mar 10, 2022
ce3624e
noinline -> forceinline.
canonizer Mar 10, 2022
63fadd1
Updated comment.
canonizer Mar 11, 2022
986a50d
Merge branch 'dev-fil-fp64' into enh-fil-func64
canonizer Mar 11, 2022
fda7669
Fixed many compiler errors.
canonizer Mar 12, 2022
40f7a23
Multiple changes.
canonizer Mar 14, 2022
12cc051
Fixed compilation errors; now it compiles.
canonizer Mar 14, 2022
2bc6b6e
float -> void in predict().
canonizer Mar 16, 2022
138e1dd
Some templating.
canonizer Mar 16, 2022
532686f
template_forest<real_t> for type-dependent forest members.
canonizer Mar 17, 2022
ce689f7
Instantiate forests with double.
canonizer Mar 17, 2022
98b997a
Small changes.
canonizer Mar 17, 2022
8c84cf7
Templatized BaseFilTest.
canonizer Mar 18, 2022
05de38d
Templatized child_index tests, added float64-only tests.
canonizer Mar 18, 2022
316e99a
float64 versions of multi-sum and FIL predict tests.
canonizer Mar 19, 2022
2ba5eed
compute_smem_footprint() uses float or double, based on sizeof_real.
canonizer Mar 19, 2022
0395979
Merge branch 'branch-22.04' into dev-fil64
canonizer Mar 22, 2022
ac92be7
Removed stray static_asserts.
canonizer Mar 22, 2022
2d51762
Merge branch 'branch-22.06' into dev-fil64
canonizer Mar 31, 2022
0db5e37
Merge branch 'branch-22.06' into dev-fil64
canonizer Apr 4, 2022
92c44af
Finish merge.
canonizer Apr 4, 2022
938e02a
Fixed compilation errors.
canonizer Apr 4, 2022
c665bbf
Fixed endless recursion in forest::free().
canonizer Apr 4, 2022
175837a
Removed changes to fil.h.
canonizer Apr 4, 2022
886c649
Refactored tests.
canonizer Apr 4, 2022
1426c14
noinline -> forceinline.
canonizer Apr 4, 2022
d1fe2e0
Merge branch 'branch-22.06' into dev-fil64
canonizer Apr 6, 2022
56ecd51
Addressed review comments.
canonizer Apr 6, 2022
9098b86
float64 support in treelite->FIL import.
canonizer Apr 7, 2022
1fd8624
float64 support in FIL Python layer.
canonizer Apr 7, 2022
aaa3261
get_forest{32,64} in fil.pyx.
canonizer Apr 7, 2022
afa4952
Merge branch 'branch-22.06' into dev-fil-tl64
canonizer Apr 7, 2022
606ee1d
Initializing forest_data with forest_variant(NULL).
canonizer Apr 7, 2022
ca7610f
Merge branch 'branch-22.06' into dev-fil-tl64
canonizer Apr 9, 2022
0d2c512
Addressed review comments.
canonizer Apr 9, 2022
4 changes: 3 additions & 1 deletion cpp/bench/sg/fil.cu
@@ -91,7 +91,9 @@ class FIL : public RegressionFixture<float> {
.threads_per_tree = 1,
.n_items = 0,
.pforest_shape_str = nullptr};
ML::fil::from_treelite(*handle, &forest, model, &tl_params);
ML::fil::forest_variant forest_variant;
ML::fil::from_treelite(*handle, &forest_variant, model, &tl_params);
forest = std::get<ML::fil::forest_t<float>>(forest_variant);

// only time prediction
this->loopOnState(state, [this]() {
11 changes: 10 additions & 1 deletion cpp/include/cuml/fil/fil.h
@@ -20,6 +20,8 @@

#include <stddef.h>

#include <variant> // for std::get<>, std::variant<>

#include <cuml/ensemble/treelite_defs.hpp>

namespace raft {
@@ -76,6 +78,13 @@ struct forest;
template <typename real_t>
using forest_t = forest<real_t>*;

/** forest32_t and forest64_t are definitions required in Cython */
using forest32_t = forest<float>*;
using forest64_t = forest<double>*;

/** forest_variant is used to get a forest represented with either float or double. */
using forest_variant = std::variant<forest_t<float>, forest_t<double>>;

/** MAX_N_ITEMS determines the maximum allowed value for tl_params::n_items */
constexpr int MAX_N_ITEMS = 4;

@@ -116,7 +125,7 @@ struct treelite_params_t {
*/
// TODO (canonizer): use std::variant<forest_t<float>, forest_t<double>>* for pforest
void from_treelite(const raft::handle_t& handle,
forest_t<float>* pforest,
forest_variant* pforest,
ModelHandle model,
const treelite_params_t* tl_params);

2 changes: 1 addition & 1 deletion cpp/src/fil/infer.cu
@@ -36,7 +36,7 @@
#endif // __CUDA_ARCH__
#endif // CUDA_PRAGMA_UNROLL

#define INLINE_CONFIG __forceinline__
#define INLINE_CONFIG __noinline__

namespace ML {
namespace fil {
69 changes: 39 additions & 30 deletions cpp/src/fil/treelite_import.cu
@@ -40,6 +40,7 @@
#include <cstddef> // for std::size_t
#include <cstdint> // for uint8_t
#include <iosfwd> // for ios, stringstream
#include <limits> // for std::numeric_limits
#include <stack> // for std::stack
#include <string> // for std::string
#include <type_traits> // for std::is_same
@@ -223,7 +224,8 @@ cat_sets_owner allocate_cat_sets_owner(const tl::ModelImpl<T, L>& model)
return cat_sets;
}

void adjust_threshold(float* pthreshold, bool* swap_child_nodes, tl::Operator comparison_op)
template <typename real_t>
void adjust_threshold(real_t* pthreshold, bool* swap_child_nodes, tl::Operator comparison_op)
{
// in treelite (take left node if val [op] threshold),
// the meaning of the condition is reversed compared to FIL;
@@ -237,12 +239,12 @@ void adjust_threshold(float* pthreshold, bool* swap_child_nodes, tl::Operator co
case tl::Operator::kLT: break;
case tl::Operator::kLE:
// x <= y is equivalent to x < y', where y' is the next representable float
*pthreshold = std::nextafterf(*pthreshold, std::numeric_limits<float>::infinity());
*pthreshold = std::nextafter(*pthreshold, std::numeric_limits<real_t>::infinity());
break;
case tl::Operator::kGT:
// x > y is equivalent to x >= y', where y' is the next representable float
// left and right still need to be swapped
*pthreshold = std::nextafterf(*pthreshold, std::numeric_limits<float>::infinity());
*pthreshold = std::nextafter(*pthreshold, std::numeric_limits<real_t>::infinity());
case tl::Operator::kGE:
// swap left and right
*swap_child_nodes = !*swap_child_nodes;
@@ -279,7 +281,7 @@ void tl2fil_leaf_payload(fil_node_t* fil_node,
const tl::Tree<T, L>& tl_tree,
int tl_node_id,
const forest_params_t& forest_params,
std::vector<float>* vector_leaf,
std::vector<typename fil_node_t::real_type>* vector_leaf,
size_t* leaf_counter)
{
auto vec = tl_tree.LeafVector(tl_node_id);
@@ -301,7 +303,7 @@
}
case leaf_algo_t::FLOAT_UNARY_BINARY:
case leaf_algo_t::GROVE_PER_CLASS:
fil_node->val.f = static_cast<float>(tl_tree.LeafValue(tl_node_id));
fil_node->val.f = static_cast<typename fil_node_t::real_type>(tl_tree.LeafValue(tl_node_id));
ASSERT(!tl_tree.HasLeafVector(tl_node_id),
"some but not all treelite leaves have leaf_vector()");
break;
@@ -323,14 +325,15 @@ conversion_state<fil_node_t> tl2fil_inner_node(int fil_left_child,
cat_sets_owner* cat_sets,
std::size_t* bit_pool_offset)
{
using real_t = typename fil_node_t::real_type;
int tl_left = tree.LeftChild(tl_node_id), tl_right = tree.RightChild(tl_node_id);
val_t<float> split = {.f = NAN}; // yes there's a default initializer already
val_t<real_t> split = {.f = std::numeric_limits<real_t>::quiet_NaN()};
int feature_id = tree.SplitIndex(tl_node_id);
bool is_categorical = tree.SplitType(tl_node_id) == tl::SplitFeatureType::kCategorical &&
tree.MatchingCategories(tl_node_id).size() > 0;
bool swap_child_nodes = false;
if (tree.SplitType(tl_node_id) == tl::SplitFeatureType::kNumerical) {
split.f = static_cast<float>(tree.Threshold(tl_node_id));
split.f = static_cast<real_t>(tree.Threshold(tl_node_id));
adjust_threshold(&split.f, &swap_child_nodes, tree.ComparisonOp(tl_node_id));
} else if (tree.SplitType(tl_node_id) == tl::SplitFeatureType::kCategorical) {
// for FIL, the list of categories is always for the right child
@@ -346,14 +349,14 @@
}
} else {
// always branch left in FIL. Already accounted for Treelite branching direction above.
split.f = NAN;
split.f = std::numeric_limits<real_t>::quiet_NaN();
}
} else {
ASSERT(false, "only numerical and categorical split nodes are supported");
}
bool default_left = tree.DefaultLeft(tl_node_id) ^ swap_child_nodes;
fil_node_t node(
val_t<float>{}, split, feature_id, default_left, false, is_categorical, fil_left_child);
val_t<real_t>{}, split, feature_id, default_left, false, is_categorical, fil_left_child);
return conversion_state<fil_node_t>{node, swap_child_nodes};
}

@@ -363,7 +366,7 @@ int tree2fil(std::vector<fil_node_t>& nodes,
const tl::Tree<T, L>& tree,
std::size_t tree_idx,
const forest_params_t& forest_params,
std::vector<float>* vector_leaf,
std::vector<typename fil_node_t::real_type>* vector_leaf,
std::size_t* leaf_counter,
cat_sets_owner* cat_sets)
{
@@ -443,10 +446,11 @@ std::stringstream depth_hist_and_max(const tl::ModelImpl<T, L>& model)
forest_shape << "Total: branches: " << total_branches << " leaves: " << total_leaves
<< " nodes: " << total_nodes << endl;
forest_shape << "Avg nodes per tree: " << setprecision(2)
<< total_nodes / (float)hist[0].n_branch_nodes << endl;
<< total_nodes / static_cast<double>(hist[0].n_branch_nodes) << endl;
forest_shape.copyfmt(default_state);
forest_shape << "Leaf depth: min: " << min_leaf_depth << " avg: " << setprecision(2) << fixed
<< leaves_times_depth / (float)total_leaves << " max: " << hist.size() - 1 << endl;
<< leaves_times_depth / static_cast<double>(total_leaves)
<< " max: " << hist.size() - 1 << endl;
forest_shape.copyfmt(default_state);

vector<char> hist_bytes(hist.size() * sizeof(hist[0]));
@@ -575,9 +579,10 @@ void node_traits<node_t>::check(const treelite::ModelImpl<threshold_t, leaf_t>&

template <typename fil_node_t, typename threshold_t, typename leaf_t>
struct tl2fil_t {
using real_t = typename fil_node_t::real_type;
std::vector<int> roots_;
std::vector<fil_node_t> nodes_;
std::vector<float> vector_leaf_;
std::vector<real_t> vector_leaf_;
forest_params_t params_;
cat_sets_owner cat_sets_;
const tl::ModelImpl<threshold_t, leaf_t>& model_;
@@ -631,7 +636,7 @@ struct tl2fil_t {
}

/// initializes FIL forest object, to be ready to infer
void init_forest(const raft::handle_t& handle, forest_t<float>* pforest)
void init_forest(const raft::handle_t& handle, forest_t<real_t>* pforest)
{
ML::fil::init(
handle, pforest, cat_sets_.accessor(), vector_leaf_, roots_.data(), nodes_.data(), &params_);
@@ -646,7 +651,7 @@

template <typename fil_node_t, typename threshold_t, typename leaf_t>
void convert(const raft::handle_t& handle,
forest_t<float>* pforest,
forest_t<typename fil_node_t::real_type>* pforest,
const tl::ModelImpl<threshold_t, leaf_t>& model,
const treelite_params_t& tl_params)
{
@@ -664,24 +669,21 @@

template <typename threshold_t, typename leaf_t>
void from_treelite(const raft::handle_t& handle,
forest_t<float>* pforest,
forest_variant* pforest_variant,
const tl::ModelImpl<threshold_t, leaf_t>& model,
const treelite_params_t* tl_params)
{
// floating-point type used for model representation
using real_t = decltype(threshold_t(0) + leaf_t(0));

// get the pointer to the right forest variant
*pforest_variant = (forest_t<real_t>)nullptr;
forest_t<real_t>* pforest = &std::get<forest_t<real_t>>(*pforest_variant);

// Invariants on threshold and leaf types
static_assert(type_supported<threshold_t>(),
"Model must contain float32 or float64 thresholds for splits");
ASSERT(type_supported<leaf_t>(), "Models with integer leaf output are not yet supported");
// Display appropriate warnings when float64 values are being casted into
// float32, as FIL only supports inferencing with float32 for the time being
if (std::is_same<threshold_t, double>::value || std::is_same<leaf_t, double>::value) {
CUML_LOG_WARN(
"Casting all thresholds and leaf values to float32, as FIL currently "
"doesn't support inferencing models with float64 values. "
"This may lead to predictions with reduced accuracy.");
}
// same as std::common_type: float+double=double, float+int64_t=float
using real_t = decltype(threshold_t(0) + leaf_t(0));

storage_type_t storage_type = tl_params->storage_type;
// build dense trees by default
@@ -702,18 +704,25 @@

switch (storage_type) {
case storage_type_t::DENSE:
convert<dense_node<float>>(handle, pforest, model, *tl_params);
convert<dense_node<real_t>>(handle, pforest, model, *tl_params);
break;
case storage_type_t::SPARSE:
convert<sparse_node16<float>>(handle, pforest, model, *tl_params);
convert<sparse_node16<real_t>>(handle, pforest, model, *tl_params);
break;
case storage_type_t::SPARSE8:
// SPARSE8 is only supported for float32
if constexpr (std::is_same_v<real_t, float>) {
convert<sparse_node8>(handle, pforest, model, *tl_params);
} else {
ASSERT(false, "SPARSE8 is only supported for float32 treelite models");
}
break;
case storage_type_t::SPARSE8: convert<sparse_node8>(handle, pforest, model, *tl_params); break;
default: ASSERT(false, "tl_params->sparse must be one of AUTO, DENSE or SPARSE");
}
}

void from_treelite(const raft::handle_t& handle,
forest_t<float>* pforest,
forest_variant* pforest,
ModelHandle model,
const treelite_params_t* tl_params)
{