diff --git a/include/treelite/base.h b/include/treelite/base.h index 9cde6f7f..32a33db6 100644 --- a/include/treelite/base.h +++ b/include/treelite/base.h @@ -14,7 +14,7 @@ namespace treelite { /*! \brief float type to be used internally */ -typedef float tl_float; +typedef double tl_float; /*! \brief feature split type */ enum class SplitFeatureType : int8_t { kNone, kNumerical, kCategorical diff --git a/include/treelite/c_api.h b/include/treelite/c_api.h index 0a00ee79..a8884185 100644 --- a/include/treelite/c_api.h +++ b/include/treelite/c_api.h @@ -338,7 +338,7 @@ TREELITE_DLL int TreeliteTreeBuilderSetNumericalTestNode( TreeBuilderHandle handle, int node_key, unsigned feature_id, const char* opname, - float threshold, int default_left, + double threshold, int default_left, int left_child_key, int right_child_key); /*! @@ -375,7 +375,7 @@ TREELITE_DLL int TreeliteTreeBuilderSetCategoricalTestNode( */ TREELITE_DLL int TreeliteTreeBuilderSetLeafNode(TreeBuilderHandle handle, int node_key, - float leaf_value); + double leaf_value); /*! * \brief Turn an empty node into a leaf vector node * The leaf vector (collection of multiple leaf weights per leaf node) is @@ -389,7 +389,7 @@ TREELITE_DLL int TreeliteTreeBuilderSetLeafNode(TreeBuilderHandle handle, */ TREELITE_DLL int TreeliteTreeBuilderSetLeafVectorNode(TreeBuilderHandle handle, int node_key, - const float* leaf_vector, + const double* leaf_vector, size_t leaf_vector_len); /*! * \brief Create a new model builder diff --git a/python/treelite/frontend.py b/python/treelite/frontend.py index 670e8c81..8fe0a282 100644 --- a/python/treelite/frontend.py +++ b/python/treelite/frontend.py @@ -438,13 +438,13 @@ def set_leaf_node(self, leaf_value): _check_call(_LIB.TreeliteTreeBuilderSetLeafVectorNode( self.tree.handle, ctypes.c_int(self.node_key), - c_array(ctypes.c_float, leaf_value), + c_array(ctypes.c_double, leaf_value), ctypes.c_size_t(len(leaf_value)))) else: _check_call(_LIB.TreeliteTreeBuilderSetLeafNode( self.tree.handle, ctypes.c_int(self.node_key), - ctypes.c_float(leaf_value))) + ctypes.c_double(leaf_value))) self.empty = False except AttributeError: raise TreeliteError('This node has never been inserted into a tree; '\ @@ -494,7 +494,7 @@ def set_numerical_test_node(self, feature_id, opname, threshold, self.tree.handle, ctypes.c_int(self.node_key), ctypes.c_uint(feature_id), c_str(opname), - ctypes.c_float(threshold), + ctypes.c_double(threshold), ctypes.c_int(1 if default_left else 0), ctypes.c_int(left_child_key), ctypes.c_int(right_child_key))) diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 65c0ac2a..6cf596fd 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -389,7 +389,7 @@ int TreeliteTreeBuilderSetRootNode(TreeBuilderHandle handle, int node_key) { int TreeliteTreeBuilderSetNumericalTestNode(TreeBuilderHandle handle, int node_key, unsigned feature_id, const char* opname, - float threshold, int default_left, + double threshold, int default_left, int left_child_key, int right_child_key) { API_BEGIN(); @@ -430,7 +430,7 @@ int TreeliteTreeBuilderSetCategoricalTestNode( } int TreeliteTreeBuilderSetLeafNode(TreeBuilderHandle handle, int node_key, - float leaf_value) { + double leaf_value) { API_BEGIN(); auto builder = static_cast(handle); CHECK(builder) << "Detected dangling reference to deleted TreeBuilder object"; @@ -441,7 +441,7 @@ int TreeliteTreeBuilderSetLeafNode(TreeBuilderHandle handle, int node_key, int TreeliteTreeBuilderSetLeafVectorNode(TreeBuilderHandle handle, int node_key, - const float* leaf_vector, + const double* leaf_vector, size_t leaf_vector_len) { API_BEGIN(); auto builder = static_cast(handle); diff --git a/src/frontend/protobuf.cc b/src/frontend/protobuf.cc index 56a1780e..773dc710 100644 --- a/src/frontend/protobuf.cc +++ b/src/frontend/protobuf.cc @@ -261,7 +261,7 @@ void ExportProtobufModel(const char* filename, const Model& model) { << "The length of leaf vector must be identical to the " << "number of output groups"; for (tl_float e : leaf_vector) { - proto_node->add_leaf_vector(static_cast(e)); + proto_node->add_leaf_vector(static_cast(e)); } CHECK_EQ(proto_node->leaf_vector_size(), leaf_vector.size()); } else { // leaf node with scalar output @@ -270,7 +270,7 @@ void ExportProtobufModel(const char* filename, const Model& model) { << "a leaf vector, *no other* leaf node can use a leaf vector"; flag_leaf_vector = 0; // now no leaf can use leaf vector - proto_node->set_leaf_value(static_cast(tree[nid].leaf_value())); + proto_node->set_leaf_value(static_cast(tree[nid].leaf_value())); } } else if (tree[nid].split_type() == SplitFeatureType::kNumerical) { // numerical split @@ -283,7 +283,7 @@ void ExportProtobufModel(const char* filename, const Model& model) { proto_node->set_split_index(static_cast(split_index)); proto_node->set_split_type(treelite_protobuf::Node_SplitFeatureType_NUMERICAL); proto_node->set_op(OpName(op)); - proto_node->set_threshold(threshold); + proto_node->set_threshold(static_cast(threshold)); Q.push({tree[nid].cleft(), proto_node->mutable_left_child()}); Q.push({tree[nid].cright(), proto_node->mutable_right_child()}); } else { // categorical split diff --git a/src/tree.proto b/src/tree.proto index 2104c918..e8030ad8 100644 --- a/src/tree.proto +++ b/src/tree.proto @@ -37,16 +37,16 @@ message Node { // expression evaluates to true; the right // child is taken otherwise. // missing if leaf or categorical split - optional float threshold = 7; // Decision threshold + optional double threshold = 7; // Decision threshold // missing if leaf or categorical split repeated uint32 left_categories = 8; // List of all categories belonging to // the left child. All other categories // will belong to the right child. // missing if leaf or numerical split - optional float leaf_value = 9; // Leaf value; missing if non-leaf + optional double leaf_value = 9; // Leaf value; missing if non-leaf // also missing if leaf_vector field exists - repeated float leaf_vector = 10; // Usually missing; only used for random + repeated double leaf_vector = 10; // Usually missing; only used for random // forests with multi-class classification optional uint64 data_count = 11; // number of data points whose traversal // paths include this node. May be diff --git a/tests/python/test_model_builder.py b/tests/python/test_model_builder.py index 82006ec3..f39e8064 100644 --- a/tests/python/test_model_builder.py +++ b/tests/python/test_model_builder.py @@ -1377,7 +1377,7 @@ def test_model_builder2(self): def test_model_builder3(self): """Test programmatic model construction using scikit-learn random forest""" X, y = load_iris(return_X_y=True) - clf = RandomForestClassifier(max_depth=3, random_state=0) + clf = RandomForestClassifier(max_depth=3, random_state=0, n_estimators=10) clf.fit(X, y) expected_prob = clf.predict_proba(X) @@ -1425,6 +1425,17 @@ def process_tree(sklearn_tree): out_prob = predictor.predict(batch) assert_almost_equal(out_prob, expected_prob) + # Test round-trip with Protobuf + model.export_protobuf('./my.buffer') + model = treelite.Model.load('./my.buffer', 'protobuf') + for toolchain in os_compatible_toolchains(): + model.export_lib(toolchain=toolchain, libpath=libpath, + params={'annotate_in': './annotation.json'}, verbose=True) + predictor = treelite.runtime.Predictor(libpath=libpath, verbose=True) + batch = treelite.runtime.Batch.from_npy2d(X) + out_prob = predictor.predict(batch) + assert_almost_equal(out_prob, expected_prob) + def test_node_insert_delete(self): """Test ability to add and remove nodes""" builder = treelite.ModelBuilder(num_feature=3)