Skip to content

Commit

Permalink
Enable Double input type for RandomForest and ZipMap (#826)
Browse files Browse the repository at this point in the history
* Enable Double input type for RandomForest and ZipMap

Signed-off-by: BowenBao <bowbao@microsoft.com>

* Adjust attr type based on opset; revert regressor changes

Signed-off-by: BowenBao <bowbao@microsoft.com>

* flake8

Signed-off-by: BowenBao <bowbao@microsoft.com>
  • Loading branch information
BowenBao authored Feb 23, 2022
1 parent d2443ca commit e9433de
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 7 deletions.
12 changes: 9 additions & 3 deletions skl2onnx/_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
class OutlierMixin:
pass

from sklearn.ensemble import IsolationForest, RandomTreesEmbedding
from sklearn.ensemble import (
IsolationForest, RandomTreesEmbedding, RandomForestClassifier)
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.linear_model import BayesianRidge
from sklearn.model_selection import GridSearchCV
Expand Down Expand Up @@ -47,7 +48,8 @@ class OutlierMixin:
from .common._topology import Topology, Variable
from .common.data_types import (
DictionaryType, Int64TensorType, SequenceType,
StringTensorType, TensorType, guess_tensor_type)
StringTensorType, TensorType, FloatTensorType,
guess_tensor_type)
from .common.utils import get_column_indices
from .common.utils_checking import check_signature
from .common.utils_classifier import get_label_classes
Expand Down Expand Up @@ -153,8 +155,12 @@ def _parse_sklearn_simple_model(scope, model, inputs, custom_parsers=None,
# be fixed in shape inference phase.
label_variable = scope.declare_local_variable(
'label', Int64TensorType())
if type(model) in [RandomForestClassifier]:
prob_dtype = FloatTensorType()
else:
prob_dtype = guess_tensor_type(inputs[0].type)
probability_tensor_variable = scope.declare_local_variable(
'probabilities', guess_tensor_type(inputs[0].type))
'probabilities', prob_dtype)
this_operator.outputs.append(label_variable)
this_operator.outputs.append(probability_tensor_variable)

Expand Down
5 changes: 3 additions & 2 deletions skl2onnx/operator_converters/random_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def convert_sklearn_random_forest_classifier(
dtype = guess_numpy_type(operator.inputs[0].type)
if dtype != np.float64:
dtype = np.float32
attr_dtype = dtype if op_version >= 3 else np.float32
op = operator.raw_operator

if hasattr(op, 'n_outputs_'):
Expand Down Expand Up @@ -209,7 +210,7 @@ def convert_sklearn_random_forest_classifier(
'target_weights', 'nodes_hitrates',
'base_values'):
attr_pairs[k] = np.array(
attr_pairs[k], dtype=dtype).ravel()
attr_pairs[k], dtype=attr_dtype).ravel()

container.add_node(
op_type, input_name,
Expand Down Expand Up @@ -249,7 +250,7 @@ def convert_sklearn_random_forest_classifier(
if k in ('nodes_values', 'class_weights',
'target_weights', 'nodes_hitrates',
'base_values'):
attrs[k] = np.array(attrs[k], dtype=dtype).ravel()
attrs[k] = np.array(attrs[k], dtype=attr_dtype).ravel()

if options['decision_path']:
# decision_path
Expand Down
3 changes: 3 additions & 0 deletions skl2onnx/shape_calculators/zip_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ def calculate_sklearn_zipmap(operator):
if len(operator.inputs) == 2:
operator.outputs[0].type = operator.inputs[0].type.__class__(
operator.inputs[0].type.shape)
if operator.outputs[1].type is not None:
operator.outputs[1].type.element_type.value_type = \
operator.inputs[1].type.__class__([])


def calculate_sklearn_zipmap_columns(operator):
Expand Down
20 changes: 19 additions & 1 deletion tests/test_sklearn_random_forest_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@
from skl2onnx.common.data_types import (
BooleanTensorType,
FloatTensorType,
Int64TensorType)
DoubleTensorType,
Int64TensorType,
)
from skl2onnx import convert_sklearn, to_onnx
from test_utils import (
binary_array_to_string,
Expand Down Expand Up @@ -252,6 +254,22 @@ def test_extra_trees_classifier_bool(self):
X, model, model_onnx,
basename="SklearnExtraTreesClassifierBool")

@ignore_warnings(category=FutureWarning)
def test_random_forest_classifier_double(self):
model, X = fit_classification_model(
RandomForestClassifier(n_estimators=5, random_state=42),
3, is_double=True)
for opv in [1, 2, 3]:
model_onnx = convert_sklearn(
model, "random forest classifier",
[("input", DoubleTensorType([None, X.shape[1]]))],
target_opset={'ai.onnx.ml': opv,
'': TARGET_OPSET})
self.assertIsNotNone(model_onnx)
dump_data_and_model(
X, model, model_onnx,
basename="SklearnRandomForestClassifierDouble")

@ignore_warnings(category=FutureWarning)
def common_test_model_hgb_regressor(self, add_nan=False):
model = HistGradientBoostingRegressor(max_iter=5, max_depth=2)
Expand Down
4 changes: 3 additions & 1 deletion tests/test_utils/tests_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ def fit_classification_model(model, n_classes, is_int=False,
pos_features=False, label_string=False,
random_state=42, is_bool=False,
n_features=20, n_redundant=None,
n_repeated=None, cls_dtype=None):
n_repeated=None, cls_dtype=None,
is_double=False,):
X, y = make_classification(
n_classes=n_classes, n_features=n_features, n_samples=250,
random_state=random_state, n_informative=min(7, n_features),
Expand All @@ -62,6 +63,7 @@ def fit_classification_model(model, n_classes, is_int=False,
if label_string:
y = numpy.array(['cl%d' % cl for cl in y])
X = X.astype(numpy.int64) if is_int or is_bool else X.astype(numpy.float32)
X = X.astype(numpy.double) if is_double else X
if pos_features:
X = numpy.abs(X)
if is_bool:
Expand Down

0 comments on commit e9433de

Please sign in to comment.