diff --git a/skl2onnx/_parse.py b/skl2onnx/_parse.py index 8a9e1ff6b..617747aea 100644 --- a/skl2onnx/_parse.py +++ b/skl2onnx/_parse.py @@ -13,7 +13,8 @@ class OutlierMixin: pass -from sklearn.ensemble import IsolationForest, RandomTreesEmbedding +from sklearn.ensemble import ( + IsolationForest, RandomTreesEmbedding, RandomForestClassifier) from sklearn.gaussian_process import GaussianProcessRegressor from sklearn.linear_model import BayesianRidge from sklearn.model_selection import GridSearchCV @@ -47,7 +48,8 @@ class OutlierMixin: from .common._topology import Topology, Variable from .common.data_types import ( DictionaryType, Int64TensorType, SequenceType, - StringTensorType, TensorType, guess_tensor_type) + StringTensorType, TensorType, FloatTensorType, + guess_tensor_type) from .common.utils import get_column_indices from .common.utils_checking import check_signature from .common.utils_classifier import get_label_classes @@ -153,8 +155,12 @@ def _parse_sklearn_simple_model(scope, model, inputs, custom_parsers=None, # be fixed in shape inference phase. label_variable = scope.declare_local_variable( 'label', Int64TensorType()) + if type(model) in [RandomForestClassifier]: + prob_dtype = FloatTensorType() + else: + prob_dtype = guess_tensor_type(inputs[0].type) probability_tensor_variable = scope.declare_local_variable( - 'probabilities', guess_tensor_type(inputs[0].type)) + 'probabilities', prob_dtype) this_operator.outputs.append(label_variable) this_operator.outputs.append(probability_tensor_variable) diff --git a/skl2onnx/operator_converters/random_forest.py b/skl2onnx/operator_converters/random_forest.py index 72f93b315..ea3c9e821 100644 --- a/skl2onnx/operator_converters/random_forest.py +++ b/skl2onnx/operator_converters/random_forest.py @@ -92,6 +92,7 @@ def convert_sklearn_random_forest_classifier( dtype = guess_numpy_type(operator.inputs[0].type) if dtype != np.float64: dtype = np.float32 + attr_dtype = dtype if op_version >= 3 else np.float32 op = operator.raw_operator if hasattr(op, 'n_outputs_'): @@ -209,7 +210,7 @@ def convert_sklearn_random_forest_classifier( 'target_weights', 'nodes_hitrates', 'base_values'): attr_pairs[k] = np.array( - attr_pairs[k], dtype=dtype).ravel() + attr_pairs[k], dtype=attr_dtype).ravel() container.add_node( op_type, input_name, @@ -249,7 +250,7 @@ def convert_sklearn_random_forest_classifier( if k in ('nodes_values', 'class_weights', 'target_weights', 'nodes_hitrates', 'base_values'): - attrs[k] = np.array(attrs[k], dtype=dtype).ravel() + attrs[k] = np.array(attrs[k], dtype=attr_dtype).ravel() if options['decision_path']: # decision_path diff --git a/skl2onnx/shape_calculators/zip_map.py b/skl2onnx/shape_calculators/zip_map.py index d9c9121bb..9e0b7f8ed 100644 --- a/skl2onnx/shape_calculators/zip_map.py +++ b/skl2onnx/shape_calculators/zip_map.py @@ -12,6 +12,9 @@ def calculate_sklearn_zipmap(operator): if len(operator.inputs) == 2: operator.outputs[0].type = operator.inputs[0].type.__class__( operator.inputs[0].type.shape) + if operator.outputs[1].type is not None: + operator.outputs[1].type.element_type.value_type = \ + operator.inputs[1].type.__class__([]) def calculate_sklearn_zipmap_columns(operator): diff --git a/tests/test_sklearn_random_forest_converters.py b/tests/test_sklearn_random_forest_converters.py index a9927f4fd..606a1bc75 100644 --- a/tests/test_sklearn_random_forest_converters.py +++ b/tests/test_sklearn_random_forest_converters.py @@ -27,7 +27,9 @@ from skl2onnx.common.data_types import ( BooleanTensorType, FloatTensorType, - Int64TensorType) + DoubleTensorType, + Int64TensorType, +) from skl2onnx import convert_sklearn, to_onnx from test_utils import ( binary_array_to_string, @@ -252,6 +254,22 @@ def test_extra_trees_classifier_bool(self): X, model, model_onnx, basename="SklearnExtraTreesClassifierBool") + @ignore_warnings(category=FutureWarning) + def test_random_forest_classifier_double(self): + model, X = fit_classification_model( + RandomForestClassifier(n_estimators=5, random_state=42), + 3, is_double=True) + for opv in [1, 2, 3]: + model_onnx = convert_sklearn( + model, "random forest classifier", + [("input", DoubleTensorType([None, X.shape[1]]))], + target_opset={'ai.onnx.ml': opv, + '': TARGET_OPSET}) + self.assertIsNotNone(model_onnx) + dump_data_and_model( + X, model, model_onnx, + basename="SklearnRandomForestClassifierDouble") + @ignore_warnings(category=FutureWarning) def common_test_model_hgb_regressor(self, add_nan=False): model = HistGradientBoostingRegressor(max_iter=5, max_depth=2) diff --git a/tests/test_utils/tests_helper.py b/tests/test_utils/tests_helper.py index d56850499..8fdf60732 100644 --- a/tests/test_utils/tests_helper.py +++ b/tests/test_utils/tests_helper.py @@ -51,7 +51,8 @@ def fit_classification_model(model, n_classes, is_int=False, pos_features=False, label_string=False, random_state=42, is_bool=False, n_features=20, n_redundant=None, - n_repeated=None, cls_dtype=None): + n_repeated=None, cls_dtype=None, + is_double=False,): X, y = make_classification( n_classes=n_classes, n_features=n_features, n_samples=250, random_state=random_state, n_informative=min(7, n_features), @@ -62,6 +63,7 @@ def fit_classification_model(model, n_classes, is_int=False, if label_string: y = numpy.array(['cl%d' % cl for cl in y]) X = X.astype(numpy.int64) if is_int or is_bool else X.astype(numpy.float32) + X = X.astype(numpy.double) if is_double else X if pos_features: X = numpy.abs(X) if is_bool: