diff --git a/.azure-pipelines/linux-conda-CI.yml b/.azure-pipelines/linux-conda-CI.yml
index 8e049cae6..5d72b4934 100644
--- a/.azure-pipelines/linux-conda-CI.yml
+++ b/.azure-pipelines/linux-conda-CI.yml
@@ -257,6 +257,8 @@ jobs:
         echo "---------------"
         pip show numpy
         echo "---------------"
+        pip show scipy
+        echo "---------------"
         pip show pandas
         echo "---------------"
         pip show onnx
diff --git a/skl2onnx/_parse.py b/skl2onnx/_parse.py
index 28ce893e0..a9f671954 100644
--- a/skl2onnx/_parse.py
+++ b/skl2onnx/_parse.py
@@ -413,17 +413,9 @@ def _parse_sklearn_random_trees_embedding(scope, model, inputs,
     return res
 
 
-def _parse_sklearn_classifier(scope, model, inputs, custom_parsers=None):
-    options = scope.get_options(model, dict(zipmap=True))
-    no_zipmap = (
-        (isinstance(options['zipmap'], bool) and not options['zipmap']) or
-        (model.__class__ in [NuSVC, SVC] and not model.probability))
-    probability_tensor = _parse_sklearn_simple_model(
-        scope, model, inputs, custom_parsers=custom_parsers)
-    if no_zipmap:
-        return probability_tensor
-
-    if options['zipmap'] == 'columns':
+def _apply_zipmap(zipmap_options, scope, model, input_type,
+                  probability_tensor):
+    if zipmap_options == 'columns':
         zipmap_operator = scope.declare_local_operator('SklearnZipMapColumns')
         classes = get_label_classes(scope, model)
         classes_names = get_label_classes(scope, model, node_names=True)
@@ -456,10 +448,11 @@ def _parse_sklearn_classifier(scope, model, inputs, custom_parsers=None):
         label_type = StringTensorType([None])
 
     zip_label = scope.declare_local_variable('output_label', label_type)
-    zipmap_operator.outputs.append(zip_label)
+    if len(probability_tensor) == 2:
+        zipmap_operator.outputs.append(zip_label)
 
-    if options['zipmap'] == 'columns':
-        prob_type = probability_tensor[1].type
+    if zipmap_options == 'columns':
+        prob_type = probability_tensor[-1].type
         for cl in classes_names:
             output_cl = scope.declare_local_variable(cl, prob_type.__class__())
             zipmap_operator.outputs.append(output_cl)
@@ -468,24 +461,84 @@ def _parse_sklearn_classifier(scope, model, inputs, custom_parsers=None):
             'output_probability',
             SequenceType(
                 DictionaryType(
-                    label_type, guess_tensor_type(inputs[0].type))))
+                    label_type, guess_tensor_type(input_type))))
         zipmap_operator.outputs.append(zip_probability)
 
     zipmap_operator.init_status(is_evaluated=True)
     return zipmap_operator.outputs
 
 
+def _parse_sklearn_classifier(scope, model, inputs, custom_parsers=None):
+    options = scope.get_options(model, dict(zipmap=True))
+    no_zipmap = (
+        (isinstance(options['zipmap'], bool) and not options['zipmap']) or
+        (model.__class__ in [NuSVC, SVC] and not model.probability))
+    probability_tensor = _parse_sklearn_simple_model(
+        scope, model, inputs, custom_parsers=custom_parsers)
+    if no_zipmap:
+        return probability_tensor
+    return _apply_zipmap(
+        options['zipmap'], scope, model, inputs[0].type, probability_tensor)
+
+
 def _parse_sklearn_multi_output_classifier(scope, model, inputs,
                                            custom_parsers=None):
     alias = _get_sklearn_operator_name(type(model))
     this_operator = scope.declare_local_operator(alias, model)
     this_operator.inputs = inputs
-    label = scope.declare_local_variable("label", Int64TensorType())
-    proba = scope.declare_local_variable(
-        "probabilities", SequenceType(guess_tensor_type(inputs[0].type)))
+    guessed_output_type = guess_tensor_type(inputs[0].type)
+    label_type = Int64TensorType()
+    label = scope.declare_local_variable("label", label_type)
+
+    options = scope.get_options(model, dict(zipmap=True))
+    no_zipmap = (
+        isinstance(options['zipmap'], bool) and not options['zipmap'])
+
+    if no_zipmap:
+        proba = scope.declare_local_variable(
+            "probabilities", SequenceType(guessed_output_type))
+        this_operator.outputs.append(label)
+        this_operator.outputs.append(proba)
+        return this_operator.outputs
+
+    warnings.warn(
+        "The current converter for class %r "
+        "creates a sequence of maps as an output. "
+        "This is not allowed by standard ONNX "
+        "specifications." % model.__class__.__name__,
+        RuntimeWarning)
+    this_operator.outputs.append(label)
+    proba = scope.declare_local_variable(
+        "output_probabilities", SequenceType(guessed_output_type))
     this_operator.outputs.append(proba)
-    return this_operator.outputs
+
+    zipmap = options.get('zipmap', False)
+    if zipmap == 'columns':
+        raise RuntimeError(
+            "Unable to convert model %r with option zipmap=%r." % (
+                model.__class__.__name__, zipmap))
+
+    # zipmap is True
+    zipped = []
+    for i in range(len(model.estimators_)):
+        output_name = scope.declare_local_variable(
+            "multi_output_%d" % i, guessed_output_type)
+        sequence_at = scope.declare_local_operator('SklearnSequenceAt')
+        sequence_at.inputs.append(proba)
+        sequence_at.outputs.append(output_name)
+        sequence_at.index = i
+        zipped.extend(
+            _apply_zipmap(
+                zipmap, scope, model.estimators_[i],
+                guessed_output_type, sequence_at.outputs))
+
+    sequence_build = scope.declare_local_operator('SklearnSequenceConstruct')
+    sequence_build.inputs.extend(zipped)
+    proba_zipped = scope.declare_local_variable(
+        "probabilities",
+        SequenceType(DictionaryType(label_type, guessed_output_type)))
+    sequence_build.outputs.append(proba_zipped)
+    return sequence_build.outputs
 
 
 def _parse_sklearn_gaussian_process(scope, model, inputs, custom_parsers=None):
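The parser change above is what makes the zipmap option effective for MultiOutputClassifier: with zipmap=False the converted model exposes a plain label tensor plus a sequence of probability matrices instead of a sequence of maps. A minimal sketch of the intended usage, mirroring the updated tests further down (the target opset value here is illustrative):

import numpy
from sklearn.datasets import make_multilabel_classification
from sklearn.linear_model import LogisticRegression
from sklearn.multioutput import MultiOutputClassifier
from onnxruntime import InferenceSession
from skl2onnx import to_onnx

X, y = make_multilabel_classification(n_classes=3, random_state=0)
X = X.astype(numpy.float32)
clf = MultiOutputClassifier(LogisticRegression(max_iter=500)).fit(X, y)

# zipmap=False keeps the probabilities as a sequence of tensors,
# one (n_samples, n_classes) matrix per target column.
onx = to_onnx(clf, X[:1], options={'zipmap': False}, target_opset=13)
sess = InferenceSession(onx.SerializeToString())
label, probas = sess.run(None, {'X': X})
print(label.shape)   # (n_samples, n_targets)
print(len(probas))   # n_targets, one probability matrix per estimator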
diff --git a/skl2onnx/operator_converters/__init__.py b/skl2onnx/operator_converters/__init__.py
index da1a7356e..2110f2dfe 100644
--- a/skl2onnx/operator_converters/__init__.py
+++ b/skl2onnx/operator_converters/__init__.py
@@ -50,6 +50,7 @@
 from . import ransac_regressor
 from . import replace_op
 from . import scaler_op
+from . import sequence
 from . import sgd_classifier
 from . import stacking
 from . import support_vector_machines
@@ -108,6 +109,7 @@
     ransac_regressor,
     replace_op,
     scaler_op,
+    sequence,
     sgd_classifier,
     stacking,
     support_vector_machines,
diff --git a/skl2onnx/operator_converters/multioutput.py b/skl2onnx/operator_converters/multioutput.py
index 79c080311..2153d8aaa 100644
--- a/skl2onnx/operator_converters/multioutput.py
+++ b/skl2onnx/operator_converters/multioutput.py
@@ -61,13 +61,12 @@ def convert_multi_output_classifier_converter(
                   op_version=op_version)
         for y in y_list]
 
-    label = OnnxConcat(*label_list, axis=1, op_version=op_version,
-                       output_names=[operator.outputs[0]])
-    label.add_to(scope=scope, container=container)
-
     # probabilities
     proba_list = [OnnxIdentity(y[1], op_version=op_version)
                   for y in y_list]
+    label = OnnxConcat(*label_list, axis=1, op_version=op_version,
+                       output_names=[operator.outputs[0]])
+    label.add_to(scope=scope, container=container)
     proba = OnnxSequenceConstruct(
         *proba_list, op_version=op_version,
@@ -79,4 +78,5 @@ def convert_multi_output_classifier_converter(
                    convert_multi_output_regressor_converter)
 register_converter('SklearnMultiOutputClassifier',
                    convert_multi_output_classifier_converter,
-                   options={'nocl': [False, True]})
+                   options={'nocl': [False, True],
+                            'zipmap': [False, True]})
diff --git a/skl2onnx/operator_converters/sequence.py b/skl2onnx/operator_converters/sequence.py
new file mode 100644
index 000000000..77b937016
--- /dev/null
+++ b/skl2onnx/operator_converters/sequence.py
@@ -0,0 +1,31 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from ..proto import onnx_proto
+from ..common._registration import register_converter
+from ..common._topology import Scope, Operator
+from ..common._container import ModelComponentContainer
+
+
+def convert_sklearn_sequence_at(scope: Scope, operator: Operator,
+                                container: ModelComponentContainer):
+    i_index = operator.index
+    index_name = scope.get_unique_variable_name("seq_at%d" % i_index)
+    container.add_initializer(
+        index_name, onnx_proto.TensorProto.INT64, [], [i_index])
+    container.add_node(
+        'SequenceAt', [operator.inputs[0].full_name, index_name],
+        operator.outputs[0].full_name,
+        name=scope.get_unique_operator_name('SequenceAt%d' % i_index))
+
+
+def convert_sklearn_sequence_construct(scope: Scope, operator: Operator,
+                                       container: ModelComponentContainer):
+    container.add_node(
+        'SequenceConstruct', [i.full_name for i in operator.inputs],
+        operator.outputs[0].full_name,
+        name=scope.get_unique_operator_name('SequenceConstruct'))
+
+
+register_converter('SklearnSequenceAt', convert_sklearn_sequence_at)
+register_converter(
+    'SklearnSequenceConstruct', convert_sklearn_sequence_construct)
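The two converters above only wrap the standard ONNX sequence operators. A self-contained sketch of what they emit, assuming opset 11 and illustrative tensor names, so the runtime behaviour can be checked in isolation:

import numpy
from onnx import TensorProto, helper
from onnxruntime import InferenceSession

# SequenceConstruct packs tensors into a sequence; SequenceAt reads one
# element back by position (SklearnSequenceAt stores it as an initializer).
nodes = [
    helper.make_node('SequenceConstruct', ['A', 'B'], ['seq']),
    helper.make_node('SequenceAt', ['seq', 'index'], ['out'])]
graph = helper.make_graph(
    nodes, 'sequence_demo',
    [helper.make_tensor_value_info('A', TensorProto.FLOAT, [None, 2]),
     helper.make_tensor_value_info('B', TensorProto.FLOAT, [None, 2]),
     helper.make_tensor_value_info('index', TensorProto.INT64, [])],
    [helper.make_tensor_value_info('out', TensorProto.FLOAT, [None, 2])])
model = helper.make_model(
    graph, opset_imports=[helper.make_opsetid('', 11)])
sess = InferenceSession(model.SerializeToString())
res = sess.run(None, {
    'A': numpy.array([[0., 1.]], dtype=numpy.float32),
    'B': numpy.array([[2., 3.]], dtype=numpy.float32),
    'index': numpy.array(1, dtype=numpy.int64)})
print(res[0])   # [[2. 3.]] -- the element stored at position 1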
diff --git a/skl2onnx/operator_converters/zip_map.py b/skl2onnx/operator_converters/zip_map.py
index d3b7bfb2d..84abbc19d 100644
--- a/skl2onnx/operator_converters/zip_map.py
+++ b/skl2onnx/operator_converters/zip_map.py
@@ -30,9 +30,24 @@ def _common_convert_sklearn_zipmap(scope: Scope, operator: Operator,
 
 
 def convert_sklearn_zipmap(scope: Scope, operator: Operator,
                            container: ModelComponentContainer):
-    zipmap_attrs = _common_convert_sklearn_zipmap(scope, operator, container)
-    container.add_node('ZipMap', operator.inputs[1].full_name,
-                       operator.outputs[1].full_name,
+    if len(operator.inputs) == 2:
+        zipmap_attrs = _common_convert_sklearn_zipmap(
+            scope, operator, container)
+        container.add_node('ZipMap', operator.inputs[1].full_name,
+                           operator.outputs[1].full_name,
+                           op_domain='ai.onnx.ml', **zipmap_attrs)
+        return
+
+    if hasattr(operator, 'classlabels_int64s'):
+        zipmap_attrs = dict(classlabels_int64s=operator.classlabels_int64s)
+    elif hasattr(operator, 'classlabels_strings'):
+        zipmap_attrs = dict(classlabels_strings=operator.classlabels_strings)
+    else:
+        raise RuntimeError(
+            "operator should have attribute 'classlabels_int64s' or "
+            "'classlabels_strings'.")
+    container.add_node('ZipMap', operator.inputs[0].full_name,
+                       operator.outputs[0].full_name,
                        op_domain='ai.onnx.ml', **zipmap_attrs)
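For reference, the ZipMap path is what produces the dictionary output consumed by the new tests: when a classifier is converted with the default zipmap=True, its second output comes back as a list with one dict per input row. A short sketch (the target opset value is illustrative):

import numpy
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from onnxruntime import InferenceSession
from skl2onnx import to_onnx

X, y = load_iris(return_X_y=True)
X = X.astype(numpy.float32)
clf = LogisticRegression(max_iter=500).fit(X, y)

onx = to_onnx(clf, X[:1], target_opset=13)  # zipmap=True is the default
sess = InferenceSession(onx.SerializeToString())
label, proba = sess.run(None, {'X': X[:2]})
print(label)   # predicted classes
print(proba)   # [{0: p0, 1: p1, 2: p2}, ...] -- one dict per input row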
diff --git a/skl2onnx/shape_calculators/__init__.py b/skl2onnx/shape_calculators/__init__.py
index 43a366db8..9c58fe126 100644
--- a/skl2onnx/shape_calculators/__init__.py
+++ b/skl2onnx/shape_calculators/__init__.py
@@ -37,6 +37,7 @@
 from . import random_trees_embedding
 from . import replace_op
 from . import scaler
+from . import sequence
 from . import svd
 from . import support_vector_machines
 from . import text_vectorizer
@@ -80,6 +81,7 @@
     random_trees_embedding,
     replace_op,
     scaler,
+    sequence,
     svd,
     support_vector_machines,
     text_vectorizer,
diff --git a/skl2onnx/shape_calculators/sequence.py b/skl2onnx/shape_calculators/sequence.py
new file mode 100644
index 000000000..4ef12cea2
--- /dev/null
+++ b/skl2onnx/shape_calculators/sequence.py
@@ -0,0 +1,16 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from ..common._registration import register_shape_calculator
+
+
+def calculate_sklearn_sequence_at(operator):
+    pass
+
+
+def calculate_sklearn_sequence_construct(operator):
+    pass
+
+
+register_shape_calculator('SklearnSequenceAt', calculate_sklearn_sequence_at)
+register_shape_calculator(
+    'SklearnSequenceConstruct', calculate_sklearn_sequence_construct)
diff --git a/skl2onnx/shape_calculators/zip_map.py b/skl2onnx/shape_calculators/zip_map.py
index 25ac5eb85..d9c9121bb 100644
--- a/skl2onnx/shape_calculators/zip_map.py
+++ b/skl2onnx/shape_calculators/zip_map.py
@@ -2,13 +2,16 @@
 
 
 from ..common._registration import register_shape_calculator
-from ..common.utils import check_input_and_output_numbers
 
 
 def calculate_sklearn_zipmap(operator):
-    check_input_and_output_numbers(operator, output_count_range=2)
-    operator.outputs[0].type = operator.inputs[0].type.__class__(
-        operator.inputs[0].type.shape)
+    if (len(operator.inputs) != len(operator.outputs) or
+            len(operator.inputs) not in (1, 2)):
+        raise RuntimeError(
+            "SklearnZipMap expects the same number of inputs and outputs, "
+            "with either one or two of each.")
+    if len(operator.inputs) == 2:
+        operator.outputs[0].type = operator.inputs[0].type.__class__(
+            operator.inputs[0].type.shape)
 
 
 def calculate_sklearn_zipmap_columns(operator):
diff --git a/tests/test_algebra_cascade.py b/tests/test_algebra_cascade.py
index 5c3cd5637..250b2e83b 100644
--- a/tests/test_algebra_cascade.py
+++ b/tests/test_algebra_cascade.py
@@ -214,35 +214,36 @@ def test_model_mlp_regressor_default(self):
         for opv in (1, 2, 7, 8, 9, 10, 11, 12, 13, onnx_opset_version()):
             if opv is not None and opv > TARGET_OPSET:
                 continue
-            try:
-                onx = convert_sklearn(
-                    model, "scikit-learn MLPRegressor",
-                    [("input", FloatTensorType([None, X_test.shape[1]]))],
-                    target_opset=opv)
-            except RuntimeError as e:
-                if ("is higher than the number of the "
-                        "installed onnx package") in str(e):
-                    continue
-                raise e
-            as_string = onx.SerializeToString()
-            try:
-                ort = InferenceSession(as_string)
-            except (RuntimeError, InvalidGraph, Fail) as e:
-                if opv in (None, 1, 2):
-                    continue
-                if opv >= onnx_opset_version():
-                    continue
-                if ("No suitable kernel definition found for "
-                        "op Cast(9)") in str(e):
-                    # too old onnxruntime
-                    continue
-                raise AssertionError(
-                    "Unable to load opv={}\n---\n{}\n---".format(
-                        opv, onx)) from e
-            res_out = ort.run(None, {'input': X_test})
-            assert len(res_out) == 1
-            res = res_out[0]
-            assert_almost_equal(exp.ravel(), res.ravel(), decimal=4)
+            with self.subTest(opv=opv):
+                try:
+                    onx = convert_sklearn(
+                        model, "scikit-learn MLPRegressor",
+                        [("input", FloatTensorType([None, X_test.shape[1]]))],
+                        target_opset=opv)
+                except RuntimeError as e:
+                    if ("is higher than the number of the "
+                            "installed onnx package") in str(e):
+                        continue
+                    raise e
+                as_string = onx.SerializeToString()
+                try:
+                    ort = InferenceSession(as_string)
+                except (RuntimeError, InvalidGraph, Fail) as e:
+                    if opv in (None, 1, 2):
+                        continue
+                    if opv >= onnx_opset_version():
+                        continue
+                    if ("No suitable kernel definition found for "
+                            "op Cast(9)") in str(e):
+                        # too old onnxruntime
+                        continue
+                    raise AssertionError(
+                        "Unable to load opv={}\n---\n{}\n---".format(
+                            opv, onx)) from e
+                res_out = ort.run(None, {'input': X_test})
+                assert len(res_out) == 1
+                res = res_out[0]
+                assert_almost_equal(exp.ravel(), res.ravel(), decimal=4)
 
 
 if __name__ == "__main__":
diff --git a/tests/test_convert_options.py b/tests/test_convert_options.py
new file mode 100644
index 000000000..07839d751
--- /dev/null
+++ b/tests/test_convert_options.py
@@ -0,0 +1,217 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+from distutils.version import StrictVersion
+import numpy
+from numpy.testing import assert_almost_equal
+from pandas import DataFrame
+from onnxruntime import InferenceSession
+from sklearn import __version__ as sklver
+from sklearn.datasets import load_iris
+from sklearn.model_selection import train_test_split
+from sklearn.linear_model import LogisticRegression
+from sklearn.multioutput import MultiOutputClassifier
+from sklearn.tree import DecisionTreeClassifier
+from skl2onnx import to_onnx
+try:
+    from sklearn.utils._testing import ignore_warnings
+except ImportError:
+    from sklearn.utils.testing import ignore_warnings
+from sklearn.exceptions import ConvergenceWarning
+from test_utils import TARGET_OPSET
+
+
+sklver = '.'.join(sklver.split('.')[:2])
+
+
+class TestConvertOptions(unittest.TestCase):
+
+    @staticmethod
+    def get_model_classifiers():
+        models = [
+            # BaggingClassifier,
+            # BernoulliNB,
+            # CategoricalNB,
+            # CalibratedClassifierCV,
+            # ComplementNB,
+            DecisionTreeClassifier(max_depth=2),
+            # ExtraTreeClassifier,
+            # ExtraTreesClassifier,
+            # GaussianNB,
+            # GaussianProcessClassifier,
+            # GradientBoostingClassifier,
+            # HistGradientBoostingClassifier,
+            # KNeighborsClassifier,
+            # LinearDiscriminantAnalysis,
+            # LinearSVC,
+            LogisticRegression(max_iter=10),
+            # LogisticRegressionCV,
+            # MLPClassifier,
+            # MultinomialNB,
+            # NuSVC,
+            # OneVsRestClassifier,
+            # PassiveAggressiveClassifier,
+            # Perceptron,
+            # RandomForestClassifier,
+            # SGDClassifier,
+            # StackingClassifier,
+            # SVC,
+            # VotingClassifier,
+        ]
+        return models
+
+    @staticmethod
+    def dict_to_array(proba_as_dict):
+        df = DataFrame(proba_as_dict)
+        return df.values
+
+    @staticmethod
+    def almost_equal(expected_label, expected_proba, label, *probas,
+                     zipmap=False, decimal=5):
+        assert_almost_equal(expected_label, label)
+        if zipmap == 'columns':
+            for row, pr in zip(expected_proba.T, probas):
+                assert_almost_equal(
+                    row.ravel(), pr.ravel(), decimal=decimal)
+        elif zipmap:
+            proba = probas[0]
+            assert_almost_equal(
+                expected_proba,
+                TestConvertOptions.dict_to_array(proba),
+                decimal=decimal)
+        else:
+            proba = probas[0]
+            assert_almost_equal(expected_proba, proba, decimal=decimal)
+
+    @unittest.skipIf(StrictVersion(sklver) < StrictVersion("0.24"),
+                     reason="known issue with strings")
+    @ignore_warnings(category=(FutureWarning, ConvergenceWarning,
+                               DeprecationWarning))
+    def test_classifier_option_zipmap(self):
+        data = load_iris()
+        X, y = data.data, data.target
+        X = X.astype(numpy.float32)
+        X_train, X_test, y_train, y_test = train_test_split(X, y)
+
+        for zipmap in [False, True, 'columns']:
+            for cls in TestConvertOptions.get_model_classifiers():
+                with self.subTest(cls=cls.__class__.__name__, zipmap=zipmap):
+                    cls.fit(X_train, y_train)
+                    expected_label = cls.predict(X_test)
+                    expected_proba = cls.predict_proba(X_test)
+
+                    onx = to_onnx(cls, X[:1], options={'zipmap': zipmap},
+                                  target_opset=TARGET_OPSET)
+                    sess = InferenceSession(onx.SerializeToString())
+                    got = sess.run(None, {'X': X_test})
+                    TestConvertOptions.almost_equal(
+                        expected_label, expected_proba, *got, zipmap=zipmap)
+
+                    onx = to_onnx(
+                        cls, X[:1],
+                        options={cls.__class__: {'zipmap': zipmap}},
+                        target_opset=TARGET_OPSET)
+                    sess = InferenceSession(onx.SerializeToString())
+                    got = sess.run(None, {'X': X_test})
+                    TestConvertOptions.almost_equal(
+                        expected_label, expected_proba, *got, zipmap=zipmap)
+
+                    onx = to_onnx(
+                        cls, X[:1],
+                        options={id(cls): {'zipmap': zipmap}},
+                        target_opset=TARGET_OPSET)
+                    sess = InferenceSession(onx.SerializeToString())
+                    got = sess.run(None, {'X': X_test})
+                    TestConvertOptions.almost_equal(
+                        expected_label, expected_proba, *got, zipmap=zipmap)
+
+    @staticmethod
+    def get_model_multi_label():
+        models = [
+            MultiOutputClassifier(DecisionTreeClassifier(max_depth=2)),
+        ]
+        return models
+
+    @staticmethod
+    def almost_equal_multi(expected_label, expected_proba, label, *probas,
+                           zipmap=False, decimal=5):
+        assert_almost_equal(expected_label, label)
+        if zipmap == 'columns':
+            for row, pr in zip(expected_proba.T, probas):
+                assert_almost_equal(
+                    row.ravel(), pr.ravel(), decimal=decimal)
+        elif zipmap:
+            for expected, proba in zip(expected_proba, probas):
+                assert_almost_equal(
+                    expected,
+                    TestConvertOptions.dict_to_array(proba),
+                    decimal=decimal)
+        else:
+            proba = probas[0]
+            assert_almost_equal(expected_proba, proba, decimal=decimal)
+
+    @unittest.skipIf(StrictVersion(sklver) < StrictVersion("0.24"),
+                     reason="known issue with strings")
+    @ignore_warnings(category=(FutureWarning, ConvergenceWarning,
+                               DeprecationWarning))
+    def test_multi_label_option_zipmap(self):
+        data = load_iris()
+        X, y = data.data, data.target
+        X = X.astype(numpy.float32)
+        y = numpy.vstack([y, 1 - y]).T
+        y[0, :] = 1
+        X_train, X_test, y_train, y_test = train_test_split(X, y)
+
+        for zipmap in [False, True, 'columns']:
+            for cls in TestConvertOptions.get_model_multi_label():
+                with self.subTest(cls=cls.__class__, zipmap=zipmap):
+                    cls.fit(X_train, y_train)
+                    expected_label = cls.predict(X_test)
+                    expected_proba = cls.predict_proba(X_test)
+
+                    if zipmap == 'columns':
+                        # Not implemented.
+                        with self.assertRaises(ValueError):
+                            to_onnx(cls, X[:1], options={'zipmap': zipmap},
+                                    target_opset=TARGET_OPSET)
+                        continue
+
+                    onx = to_onnx(cls, X[:1], options={'zipmap': zipmap},
+                                  target_opset=TARGET_OPSET)
+
+                    if zipmap:
+                        # The converter works but SequenceConstruct
+                        # does not support Sequence of Maps.
+                        continue
+
+                    sess = InferenceSession(onx.SerializeToString())
+                    got = sess.run(None, {'X': X_test})
+                    TestConvertOptions.almost_equal_multi(
+                        expected_label, expected_proba, *got, zipmap=zipmap)
+
+                    onx = to_onnx(
+                        cls, X[:1],
+                        options={cls.__class__: {'zipmap': zipmap}},
+                        target_opset=TARGET_OPSET)
+                    sess = InferenceSession(onx.SerializeToString())
+                    got = sess.run(None, {'X': X_test})
+                    assert_almost_equal(expected_label, got[0])
+
+                    onx = to_onnx(
+                        cls, X[:1],
+                        options={id(cls): {'zipmap': zipmap}},
+                        target_opset=TARGET_OPSET)
+                    sess = InferenceSession(onx.SerializeToString())
+                    got = sess.run(None, {'X': X_test})
+                    assert_almost_equal(expected_label, got[0])
+
+
+if __name__ == "__main__":
+    # import logging
+    # log = logging.getLogger('skl2onnx')
+    # log.setLevel(logging.DEBUG)
+    # logging.basicConfig(level=logging.DEBUG)
+    # TestConvertOptions().test_multi_label_option_zipmap()
+    unittest.main()
diff --git a/tests/test_sklearn_custom_nmf.py b/tests/test_sklearn_custom_nmf.py
index 3961ed138..e873e7a64 100644
--- a/tests/test_sklearn_custom_nmf.py
+++ b/tests/test_sklearn_custom_nmf.py
@@ -21,7 +21,7 @@ class TestSklearnCustomNMF(unittest.TestCase):
     def test_custom_nmf(self):
 
         mat = np.array([[1, 0, 0, 0], [1, 0, 0, 0], [1, 0, 0, 0],
-                        [1, 0, 0, 0], [0, 0, 0, 1]], dtype=np.float64)
+                        [1, 0, 0, 0], [0, 0, 1, 0]], dtype=np.float64)
         mat[:mat.shape[1], :] += np.identity(mat.shape[1])
 
         mod = NMF(n_components=2, max_iter=2)
diff --git a/tests/test_sklearn_multi_output.py b/tests/test_sklearn_multi_output.py
index f8e0025cb..fb95a06b1 100644
--- a/tests/test_sklearn_multi_output.py
+++ b/tests/test_sklearn_multi_output.py
@@ -8,6 +8,10 @@
 from sklearn.datasets import load_linnerud, make_multilabel_classification
 from sklearn.multioutput import MultiOutputRegressor, MultiOutputClassifier
 from sklearn.linear_model import Ridge, LogisticRegression
+try:
+    from sklearn.utils._testing import ignore_warnings
+except ImportError:
+    from sklearn.utils.testing import ignore_warnings
 from skl2onnx import to_onnx
 from test_utils import dump_data_and_model, TARGET_OPSET
 
@@ -33,14 +37,14 @@ def test_multi_output_regressor(self):
 
     @unittest.skipIf(TARGET_OPSET < 11,
                      reason="SequenceConstruct not available.")
+    @ignore_warnings(category=(FutureWarning,
+                               DeprecationWarning))
    def test_multi_output_classifier(self):
         X, y = make_multilabel_classification(n_classes=3, random_state=0)
         X = X.astype(numpy.float32)
         clf = MultiOutputClassifier(LogisticRegression()).fit(X, y)
-        with self.assertRaises(NameError):
-            to_onnx(clf, X[:1], target_opset=TARGET_OPSET,
-                    options={id(clf): {'zipmap': False}})
-        onx = to_onnx(clf, X[:1], target_opset=TARGET_OPSET)
+        onx = to_onnx(clf, X[:1], target_opset=TARGET_OPSET,
+                      options={'zipmap': False})
         self.assertNotIn("ZipMap", str(onx))
 
         sess = InferenceSession(onx.SerializeToString())
@@ -54,7 +58,7 @@ def test_multi_output_classifier(self):
 
         # check option nocl=True
         onx = to_onnx(clf, X[:1], target_opset=TARGET_OPSET,
-                      options={id(clf): {'nocl': True}})
+                      options={id(clf): {'nocl': True, 'zipmap': False}})
         self.assertNotIn("ZipMap", str(onx))
 
         sess = InferenceSession(onx.SerializeToString())
@@ -68,7 +72,7 @@ def test_multi_output_classifier(self):
 
         # check option nocl=False
         onx = to_onnx(clf, X[:1], target_opset=TARGET_OPSET,
-                      options={id(clf): {'nocl': False}})
+                      options={id(clf): {'nocl': False, 'zipmap': False}})
         self.assertNotIn("ZipMap", str(onx))
 
         sess = InferenceSession(onx.SerializeToString())
diff --git a/tests/test_sklearn_pipeline.py b/tests/test_sklearn_pipeline.py
index d58674064..989b55155 100644
--- a/tests/test_sklearn_pipeline.py
+++ b/tests/test_sklearn_pipeline.py
@@ -52,6 +52,9 @@
 from onnxruntime import __version__ as ort_version, InferenceSession
 
 
+ort_version = ".".join(ort_version.split('.')[:2])
+
+
 def check_scikit_version():
     # StrictVersion does not work with development versions
     vers = '.'.join(sklearn_version.split('.')[:2])
@@ -699,7 +702,8 @@ def test_pipeline_pipeline_rf(self):
              ('A', StringTensorType([None, 1])),
              ('B', StringTensorType([None, 1])),
              ('TEXT', StringTensorType([None, 1]))],
-            target_opset=TARGET_OPSET)
+            target_opset=TARGET_OPSET,
+            options={MultiOutputClassifier: {'zipmap': False}})
         # with open("debug.onnx", "wb") as f:
         #     f.write(model_onnx.SerializeToString())
         sess = InferenceSession(model_onnx.SerializeToString())
@@ -743,7 +747,8 @@ def test_issue_712_multio(self):
         inputs = {'CAT1': dfx['CAT1'].values.reshape((-1, 1)),
                   'CAT2': dfx['CAT2'].values.reshape((-1, 1)),
                   'TEXT': dfx['TEXT'].values.reshape((-1, 1))}
-        onx = to_onnx(rf_clf, dfx, target_opset=TARGET_OPSET)
+        onx = to_onnx(rf_clf, dfx, target_opset=TARGET_OPSET,
+                      options={MultiOutputClassifier: {'zipmap': False}})
         sess = InferenceSession(onx.SerializeToString())
         got = sess.run(None, inputs)
 
@@ -799,7 +804,9 @@ def test_issue_712_svc_multio(self):
         inputs = {'CAT1': dfx['CAT1'].values.reshape((-1, 1)),
                   'CAT2': dfx['CAT2'].values.reshape((-1, 1)),
                   'TEXT': dfx['TEXT'].values.reshape((-1, 1))}
-        onx = to_onnx(rf_clf, dfx, target_opset=TARGET_OPSET)
+        onx = to_onnx(
+            rf_clf, dfx, target_opset=TARGET_OPSET,
+            options={MultiOutputClassifier: {'zipmap': False}})
         sess = InferenceSession(onx.SerializeToString())
         got = sess.run(None, inputs)
         assert_almost_equal(expected_label, got[0])
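The pipeline tests above pass the new option keyed by the estimator class, so the MultiOutputClassifier found while parsing the pipeline is converted without ZipMap (the sequence-of-maps output is not supported at runtime, as noted in test_convert_options.py). A minimal sketch of that form; the pipeline, data and target opset are illustrative:

import numpy
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.multioutput import MultiOutputClassifier
from sklearn.tree import DecisionTreeClassifier
from skl2onnx import to_onnx

X = numpy.random.rand(40, 4).astype(numpy.float32)
y = (numpy.random.rand(40, 2) > 0.5).astype(numpy.int64)
pipe = Pipeline([
    ('scale', StandardScaler()),
    ('clf', MultiOutputClassifier(DecisionTreeClassifier(max_depth=2)))])
pipe.fit(X, y)

# Keying the options by the class avoids having to fish out id(clf)
# from inside the fitted pipeline.
onx = to_onnx(pipe, X[:1], target_opset=13,
              options={MultiOutputClassifier: {'zipmap': False}})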