Skip to content

Commit

Permalink
Use double-precision to store split thresholds (dmlc#55)
Browse files Browse the repository at this point in the history
  • Loading branch information
hcho3 authored Oct 9, 2018
1 parent c9023e7 commit 79eb179
Show file tree
Hide file tree
Showing 7 changed files with 28 additions and 17 deletions.
2 changes: 1 addition & 1 deletion include/treelite/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
namespace treelite {

/*! \brief float type to be used internally */
typedef float tl_float;
typedef double tl_float;
/*! \brief feature split type */
enum class SplitFeatureType : int8_t {
kNone, kNumerical, kCategorical
Expand Down
6 changes: 3 additions & 3 deletions include/treelite/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ TREELITE_DLL int TreeliteTreeBuilderSetNumericalTestNode(
TreeBuilderHandle handle,
int node_key, unsigned feature_id,
const char* opname,
float threshold, int default_left,
double threshold, int default_left,
int left_child_key,
int right_child_key);
/*!
Expand Down Expand Up @@ -375,7 +375,7 @@ TREELITE_DLL int TreeliteTreeBuilderSetCategoricalTestNode(
*/
TREELITE_DLL int TreeliteTreeBuilderSetLeafNode(TreeBuilderHandle handle,
int node_key,
float leaf_value);
double leaf_value);
/*!
* \brief Turn an empty node into a leaf vector node
* The leaf vector (collection of multiple leaf weights per leaf node) is
Expand All @@ -389,7 +389,7 @@ TREELITE_DLL int TreeliteTreeBuilderSetLeafNode(TreeBuilderHandle handle,
*/
TREELITE_DLL int TreeliteTreeBuilderSetLeafVectorNode(TreeBuilderHandle handle,
int node_key,
const float* leaf_vector,
const double* leaf_vector,
size_t leaf_vector_len);
/*!
* \brief Create a new model builder
Expand Down
6 changes: 3 additions & 3 deletions python/treelite/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,13 +438,13 @@ def set_leaf_node(self, leaf_value):
_check_call(_LIB.TreeliteTreeBuilderSetLeafVectorNode(
self.tree.handle,
ctypes.c_int(self.node_key),
c_array(ctypes.c_float, leaf_value),
c_array(ctypes.c_double, leaf_value),
ctypes.c_size_t(len(leaf_value))))
else:
_check_call(_LIB.TreeliteTreeBuilderSetLeafNode(
self.tree.handle,
ctypes.c_int(self.node_key),
ctypes.c_float(leaf_value)))
ctypes.c_double(leaf_value)))
self.empty = False
except AttributeError:
raise TreeliteError('This node has never been inserted into a tree; '\
Expand Down Expand Up @@ -494,7 +494,7 @@ def set_numerical_test_node(self, feature_id, opname, threshold,
self.tree.handle,
ctypes.c_int(self.node_key),
ctypes.c_uint(feature_id), c_str(opname),
ctypes.c_float(threshold),
ctypes.c_double(threshold),
ctypes.c_int(1 if default_left else 0),
ctypes.c_int(left_child_key),
ctypes.c_int(right_child_key)))
Expand Down
6 changes: 3 additions & 3 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ int TreeliteTreeBuilderSetRootNode(TreeBuilderHandle handle, int node_key) {
int TreeliteTreeBuilderSetNumericalTestNode(TreeBuilderHandle handle,
int node_key, unsigned feature_id,
const char* opname,
float threshold, int default_left,
double threshold, int default_left,
int left_child_key,
int right_child_key) {
API_BEGIN();
Expand Down Expand Up @@ -430,7 +430,7 @@ int TreeliteTreeBuilderSetCategoricalTestNode(
}

int TreeliteTreeBuilderSetLeafNode(TreeBuilderHandle handle, int node_key,
float leaf_value) {
double leaf_value) {
API_BEGIN();
auto builder = static_cast<frontend::TreeBuilder*>(handle);
CHECK(builder) << "Detected dangling reference to deleted TreeBuilder object";
Expand All @@ -441,7 +441,7 @@ int TreeliteTreeBuilderSetLeafNode(TreeBuilderHandle handle, int node_key,

int TreeliteTreeBuilderSetLeafVectorNode(TreeBuilderHandle handle,
int node_key,
const float* leaf_vector,
const double* leaf_vector,
size_t leaf_vector_len) {
API_BEGIN();
auto builder = static_cast<frontend::TreeBuilder*>(handle);
Expand Down
6 changes: 3 additions & 3 deletions src/frontend/protobuf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ void ExportProtobufModel(const char* filename, const Model& model) {
<< "The length of leaf vector must be identical to the "
<< "number of output groups";
for (tl_float e : leaf_vector) {
proto_node->add_leaf_vector(static_cast<float>(e));
proto_node->add_leaf_vector(static_cast<double>(e));
}
CHECK_EQ(proto_node->leaf_vector_size(), leaf_vector.size());
} else { // leaf node with scalar output
Expand All @@ -270,7 +270,7 @@ void ExportProtobufModel(const char* filename, const Model& model) {
<< "a leaf vector, *no other* leaf node can use a leaf vector";
flag_leaf_vector = 0; // now no leaf can use leaf vector

proto_node->set_leaf_value(static_cast<float>(tree[nid].leaf_value()));
proto_node->set_leaf_value(static_cast<double>(tree[nid].leaf_value()));
}
} else if (tree[nid].split_type() == SplitFeatureType::kNumerical) {
// numerical split
Expand All @@ -283,7 +283,7 @@ void ExportProtobufModel(const char* filename, const Model& model) {
proto_node->set_split_index(static_cast<google::protobuf::int32>(split_index));
proto_node->set_split_type(treelite_protobuf::Node_SplitFeatureType_NUMERICAL);
proto_node->set_op(OpName(op));
proto_node->set_threshold(threshold);
proto_node->set_threshold(static_cast<double>(threshold));
Q.push({tree[nid].cleft(), proto_node->mutable_left_child()});
Q.push({tree[nid].cright(), proto_node->mutable_right_child()});
} else { // categorical split
Expand Down
6 changes: 3 additions & 3 deletions src/tree.proto
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,16 @@ message Node {
// expression evaluates to true; the right
// child is taken otherwise.
// missing if leaf or categorical split
optional float threshold = 7; // Decision threshold
optional double threshold = 7; // Decision threshold
// missing if leaf or categorical split
repeated uint32 left_categories = 8;
// List of all categories belonging to
// the left child. All other categories
// will belong to the right child.
// missing if leaf or numerical split
optional float leaf_value = 9; // Leaf value; missing if non-leaf
optional double leaf_value = 9; // Leaf value; missing if non-leaf
// also missing if leaf_vector field exists
repeated float leaf_vector = 10; // Usually missing; only used for random
repeated double leaf_vector = 10; // Usually missing; only used for random
// forests with multi-class classification
optional uint64 data_count = 11; // number of data points whose traversal
// paths include this node. May be
Expand Down
13 changes: 12 additions & 1 deletion tests/python/test_model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1377,7 +1377,7 @@ def test_model_builder2(self):
def test_model_builder3(self):
"""Test programmatic model construction using scikit-learn random forest"""
X, y = load_iris(return_X_y=True)
clf = RandomForestClassifier(max_depth=3, random_state=0)
clf = RandomForestClassifier(max_depth=3, random_state=0, n_estimators=10)
clf.fit(X, y)
expected_prob = clf.predict_proba(X)

Expand Down Expand Up @@ -1425,6 +1425,17 @@ def process_tree(sklearn_tree):
out_prob = predictor.predict(batch)
assert_almost_equal(out_prob, expected_prob)

# Test round-trip with Protobuf
model.export_protobuf('./my.buffer')
model = treelite.Model.load('./my.buffer', 'protobuf')
for toolchain in os_compatible_toolchains():
model.export_lib(toolchain=toolchain, libpath=libpath,
params={'annotate_in': './annotation.json'}, verbose=True)
predictor = treelite.runtime.Predictor(libpath=libpath, verbose=True)
batch = treelite.runtime.Batch.from_npy2d(X)
out_prob = predictor.predict(batch)
assert_almost_equal(out_prob, expected_prob)

def test_node_insert_delete(self):
"""Test ability to add and remove nodes"""
builder = treelite.ModelBuilder(num_feature=3)
Expand Down

0 comments on commit 79eb179

Please sign in to comment.