Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable Double input type for RandomForest and ZipMap #826

Merged
merged 3 commits into from
Feb 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions skl2onnx/_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
class OutlierMixin:
pass

from sklearn.ensemble import IsolationForest, RandomTreesEmbedding
from sklearn.ensemble import (
IsolationForest, RandomTreesEmbedding, RandomForestClassifier)
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.linear_model import BayesianRidge
from sklearn.model_selection import GridSearchCV
Expand Down Expand Up @@ -47,7 +48,8 @@ class OutlierMixin:
from .common._topology import Topology, Variable
from .common.data_types import (
DictionaryType, Int64TensorType, SequenceType,
StringTensorType, TensorType, guess_tensor_type)
StringTensorType, TensorType, FloatTensorType,
guess_tensor_type)
from .common.utils import get_column_indices
from .common.utils_checking import check_signature
from .common.utils_classifier import get_label_classes
Expand Down Expand Up @@ -153,8 +155,12 @@ def _parse_sklearn_simple_model(scope, model, inputs, custom_parsers=None,
# be fixed in shape inference phase.
label_variable = scope.declare_local_variable(
'label', Int64TensorType())
if type(model) in [RandomForestClassifier]:
prob_dtype = FloatTensorType()
else:
prob_dtype = guess_tensor_type(inputs[0].type)
probability_tensor_variable = scope.declare_local_variable(
'probabilities', guess_tensor_type(inputs[0].type))
'probabilities', prob_dtype)
this_operator.outputs.append(label_variable)
this_operator.outputs.append(probability_tensor_variable)

Expand Down
5 changes: 3 additions & 2 deletions skl2onnx/operator_converters/random_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def convert_sklearn_random_forest_classifier(
dtype = guess_numpy_type(operator.inputs[0].type)
if dtype != np.float64:
dtype = np.float32
attr_dtype = dtype if op_version >= 3 else np.float32
op = operator.raw_operator

if hasattr(op, 'n_outputs_'):
Expand Down Expand Up @@ -209,7 +210,7 @@ def convert_sklearn_random_forest_classifier(
'target_weights', 'nodes_hitrates',
'base_values'):
attr_pairs[k] = np.array(
attr_pairs[k], dtype=dtype).ravel()
attr_pairs[k], dtype=attr_dtype).ravel()

container.add_node(
op_type, input_name,
Expand Down Expand Up @@ -249,7 +250,7 @@ def convert_sklearn_random_forest_classifier(
if k in ('nodes_values', 'class_weights',
'target_weights', 'nodes_hitrates',
'base_values'):
attrs[k] = np.array(attrs[k], dtype=dtype).ravel()
attrs[k] = np.array(attrs[k], dtype=attr_dtype).ravel()

if options['decision_path']:
# decision_path
Expand Down
3 changes: 3 additions & 0 deletions skl2onnx/shape_calculators/zip_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ def calculate_sklearn_zipmap(operator):
if len(operator.inputs) == 2:
operator.outputs[0].type = operator.inputs[0].type.__class__(
operator.inputs[0].type.shape)
if operator.outputs[1].type is not None:
operator.outputs[1].type.element_type.value_type = \
operator.inputs[1].type.__class__([])


def calculate_sklearn_zipmap_columns(operator):
Expand Down
20 changes: 19 additions & 1 deletion tests/test_sklearn_random_forest_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@
from skl2onnx.common.data_types import (
BooleanTensorType,
FloatTensorType,
Int64TensorType)
DoubleTensorType,
Int64TensorType,
)
from skl2onnx import convert_sklearn, to_onnx
from test_utils import (
binary_array_to_string,
Expand Down Expand Up @@ -252,6 +254,22 @@ def test_extra_trees_classifier_bool(self):
X, model, model_onnx,
basename="SklearnExtraTreesClassifierBool")

@ignore_warnings(category=FutureWarning)
def test_random_forest_classifier_double(self):
model, X = fit_classification_model(
RandomForestClassifier(n_estimators=5, random_state=42),
3, is_double=True)
for opv in [1, 2, 3]:
model_onnx = convert_sklearn(
model, "random forest classifier",
[("input", DoubleTensorType([None, X.shape[1]]))],
target_opset={'ai.onnx.ml': opv,
'': TARGET_OPSET})
self.assertIsNotNone(model_onnx)
dump_data_and_model(
X, model, model_onnx,
basename="SklearnRandomForestClassifierDouble")

@ignore_warnings(category=FutureWarning)
def common_test_model_hgb_regressor(self, add_nan=False):
model = HistGradientBoostingRegressor(max_iter=5, max_depth=2)
Expand Down
4 changes: 3 additions & 1 deletion tests/test_utils/tests_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ def fit_classification_model(model, n_classes, is_int=False,
pos_features=False, label_string=False,
random_state=42, is_bool=False,
n_features=20, n_redundant=None,
n_repeated=None, cls_dtype=None):
n_repeated=None, cls_dtype=None,
is_double=False,):
X, y = make_classification(
n_classes=n_classes, n_features=n_features, n_samples=250,
random_state=random_state, n_informative=min(7, n_features),
Expand All @@ -62,6 +63,7 @@ def fit_classification_model(model, n_classes, is_int=False,
if label_string:
y = numpy.array(['cl%d' % cl for cl in y])
X = X.astype(numpy.int64) if is_int or is_bool else X.astype(numpy.float32)
X = X.astype(numpy.double) if is_double else X
if pos_features:
X = numpy.abs(X)
if is_bool:
Expand Down