diff --git a/CHANGELOGS.md b/CHANGELOGS.md
index ed5de5ed5..fd10339ec 100644
--- a/CHANGELOGS.md
+++ b/CHANGELOGS.md
@@ -2,9 +2,12 @@
 
 ## 1.17.0 (development)
 
+* Fix the conversion of a pipeline that includes other pipelines,
+  issue [#1069](https://github.com/onnx/sklearn-onnx/pull/1069),
+  [#1072](https://github.com/onnx/sklearn-onnx/pull/1072)
 * Fix unexpected type for intercept in PoissonRegressor and GammaRegressor
   [#1070](https://github.com/onnx/sklearn-onnx/pull/1070)
-* Add support for scikti-learn 1.4.0,
+* Add support for scikit-learn 1.4.0,
   [#1058](https://github.com/onnx/sklearn-onnx/pull/1058),
   fixes issues [Many examples in the gallery are showing "broken"](https://github.com/onnx/sklearn-onnx/pull/1057),
   [TFIDF vectorizer target_opset issue](https://github.com/onnx/sklearn-onnx/pull/1055),
diff --git a/skl2onnx/_parse.py b/skl2onnx/_parse.py
index 78fdbaad9..b3c158fab 100644
--- a/skl2onnx/_parse.py
+++ b/skl2onnx/_parse.py
@@ -288,6 +288,12 @@ def _parse_sklearn_pipeline(scope, model, inputs, custom_parsers=None):
     :return: A list of output variables produced by the input pipeline
     """
     for step in model.steps:
+        if (
+            scope.options is None
+            or (id(step[1]) not in scope.options and type(step[1]) not in scope.options)
+        ) and is_classifier(step[1]):
+            options = scope.get_options(model, {"zipmap": True})
+            scope.add_options(id(step[1]), options)
         inputs = _parse_sklearn(scope, step[1], inputs, custom_parsers=custom_parsers)
     return inputs
 
@@ -543,7 +549,7 @@ def _parse_sklearn_classifier(scope, model, inputs, custom_parsers=None):
 
     if options.get("output_class_labels", False):
         raise RuntimeError(
-            "Option 'output_class_labels' is not compatible with option " "'zipmap'."
+            "Option 'output_class_labels' is not compatible with option 'zipmap'."
        )
 
     return _apply_zipmap(
diff --git a/skl2onnx/common/_container.py b/skl2onnx/common/_container.py
index ce42e43b6..7245612e2 100644
--- a/skl2onnx/common/_container.py
+++ b/skl2onnx/common/_container.py
@@ -917,7 +917,7 @@ def ensure_topological_order(self):
 
         def nstr(name):
             if name in order:
-                return "%s#%d" % (name, order[name])
+                return "%s -- #%d" % (name, order[name])
             return name
 
         rows = [
@@ -937,17 +937,91 @@ def nstr(name):
             "%s|%s(%s) -> [%s]"
             % (
                 n.op_type,
-                n.name or n.op_type,
+                n.name or "",
                 ", ".join(map(nstr, n.input)),
                 ", ".join(n.output),
             )
             for n in self.nodes
         )
+        rows.append("--")
+        rows.append("--full-marked-nodes-with-output--")
+        rows.append("--")
+        set_name = set(order)
+        for n in self.nodes:
+            if set(n.input) & set_name == set(n.input) and set(
+                n.output
+            ) & set_name == set(n.output):
+                rows.append(
+                    "%s|%s(%s) -> [%s]"
+                    % (
+                        n.op_type,
+                        n.name or "",
+                        ", ".join(map(nstr, n.input)),
+                        ", ".join(n.output),
+                    )
+                )
+        rows.append("--")
+        rows.append("--full-marked-nodes-only-input--")
+        rows.append("--")
+        set_name = set(order)
+        for n in self.nodes:
+            if set(n.input) & set_name == set(n.input) and set(
+                n.output
+            ) & set_name != set(n.output):
+                rows.append(
+                    "%s|%s(%s) -> [%s]"
+                    % (
+                        n.op_type,
+                        n.name or "",
+                        ", ".join(map(nstr, n.input)),
+                        ", ".join(n.output),
+                    )
+                )
+        rows.append("--")
+        rows.append("--not-fully-marked-nodes--")
+        rows.append("--")
+        set_name = set(order)
+        for n in self.nodes:
+            if (set(n.input) & set_name) and set(n.input) & set_name != set(
+                n.input
+            ):
+                rows.append(
+                    "%s|%s(%s) -> [%s]"
+                    % (
+                        n.op_type,
+                        n.name or "",
+                        ", ".join(map(nstr, n.input)),
+                        ", ".join(n.output),
+                    )
+                )
+        rows.append("--")
+        rows.append("--unmarked-nodes--")
+        rows.append("--")
+        for n in self.nodes:
+            if not (set(n.input) & set_name):
+                rows.append(
+                    "%s|%s(%s) -> [%s]"
+                    % (
+                        n.op_type,
+                        n.name or "",
+                        ", ".join(map(nstr, n.input)),
+                        ", ".join(n.output),
+                    )
+                )
+
         raise RuntimeError(
             "After %d iterations for %d nodes, still unable "
             "to sort names %r. The graph may be disconnected. "
-            "List of operators: %s"
-            % (n_iter, len(self.nodes), missing_names, "\n".join(rows))
+            "List of operators:\n---------%s\n-------"
+            "\n-- ORDER--\n%s\n-- MISSING --\n%s"
+            % (
+                n_iter,
+                len(self.nodes),
+                missing_names,
+                "\n".join(rows),
+                pprint.pformat(list(sorted([(v, k) for k, v in order.items()]))),
+                pprint.pformat(set(order.values()) & set(missing_names)),
+            )
         )
 
         # Update order
diff --git a/skl2onnx/operator_converters/array_feature_extractor.py b/skl2onnx/operator_converters/array_feature_extractor.py
index 99b7639e6..993e93be7 100644
--- a/skl2onnx/operator_converters/array_feature_extractor.py
+++ b/skl2onnx/operator_converters/array_feature_extractor.py
@@ -16,17 +16,14 @@ def convert_sklearn_array_feature_extractor(
     column_indices_name = scope.get_unique_variable_name("column_indices")
 
     for i, ind in enumerate(operator.column_indices):
-        if not isinstance(ind, int):
-            raise RuntimeError(
-                (
-                    "Column {0}:'{1}' indices must be specified "
-                    "as integers. This error may happen when "
-                    "column names are used to define a "
-                    "ColumnTransformer. Column name in input data "
-                    "do not necessarily match input variables "
-                    "defined for the ONNX model."
-                ).format(i, ind)
-            )
+        assert isinstance(ind, int), (
+            "Column {0}:'{1}' indices must be specified "
+            "as integers. This error may happen when "
+            "column names are used to define a "
+            "ColumnTransformer. Column names in input data "
+            "do not necessarily match the input variables "
+            "defined for the ONNX model."
+        ).format(i, ind)
 
     container.add_initializer(
         column_indices_name,
         onnx_proto.TensorProto.INT64,
diff --git a/skl2onnx/operator_converters/pipelines.py b/skl2onnx/operator_converters/pipelines.py
index adf8734c6..e11360bfe 100644
--- a/skl2onnx/operator_converters/pipelines.py
+++ b/skl2onnx/operator_converters/pipelines.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from sklearn.base import is_classifier
+from sklearn.pipeline import Pipeline
 from ..common._registration import register_converter
 from ..common._topology import Scope, Operator
 from ..common._container import ModelComponentContainer
@@ -16,7 +17,7 @@ def convert_pipeline(
     inputs = operator.inputs
     for step in model.steps:
         step_model = step[1]
-        if is_classifier(step_model):
+        if is_classifier(step_model) or isinstance(step_model, Pipeline):
             scope.add_options(id(step_model), options={"zipmap": False})
             container.add_options(id(step_model), options={"zipmap": False})
         outputs = _parse_sklearn(scope, step_model, inputs, custom_parsers=None)
@@ -52,7 +53,7 @@ def convert_feature_union(
     scope: Scope, operator: Operator, container: ModelComponentContainer
 ):
     raise NotImplementedError(
-        "This converter not needed so far. It is usually handled " "during parsing."
+        "This converter is not needed so far. It is usually handled during parsing."
     )
 
 
@@ -60,10 +61,19 @@ def convert_column_transformer(
     scope: Scope, operator: Operator, container: ModelComponentContainer
 ):
     raise NotImplementedError(
-        "This converter not needed so far. It is usually handled " "during parsing."
+        "This converter is not needed so far. It is usually handled during parsing."
     )
 
 
-register_converter("SklearnPipeline", convert_pipeline)
+register_converter(
+    "SklearnPipeline",
+    convert_pipeline,
+    options={
+        "zipmap": [True, False, "columns"],
+        "nocl": [True, False],
+        "output_class_labels": [False, True],
+        "raw_scores": [True, False],
+    },
+)
 register_converter("SklearnFeatureUnion", convert_feature_union)
 register_converter("SklearnColumnTransformer", convert_column_transformer)
diff --git a/skl2onnx/operator_converters/stacking.py b/skl2onnx/operator_converters/stacking.py
index 889081c63..a203dadfb 100644
--- a/skl2onnx/operator_converters/stacking.py
+++ b/skl2onnx/operator_converters/stacking.py
@@ -21,13 +21,15 @@ def _fetch_scores(
         container.add_options(id(model), {"raw_scores": raw_scores})
     this_operator.inputs.append(inputs)
     if is_regressor:
-        output_proba = scope.declare_local_variable("variable", inputs.type.__class__())
+        output_proba = scope.declare_local_variable(
+            "sv_variable", inputs.type.__class__()
+        )
         this_operator.outputs.append(output_proba)
     else:
-        label_name = scope.declare_local_variable("label", Int64TensorType())
+        label_name = scope.declare_local_variable("sv_label", Int64TensorType())
         this_operator.outputs.append(label_name)
         output_proba = scope.declare_local_variable(
-            "probability_tensor", inputs.type.__class__()
+            "sv_prob_tensor", inputs.type.__class__()
         )
         this_operator.outputs.append(output_proba)
 
@@ -43,7 +45,9 @@ def _add_passthrough_connection(operator, predictions):
 
 
 def _transform_regressor(scope, operator, container, model):
-    merged_prob_tensor = scope.get_unique_variable_name("merged_probability_tensor")
+    merged_prob_tensor = scope.get_unique_variable_name(
+        "stack_tr_reg_merged_probability_tensor"
+    )
 
     predictions = [
         _fetch_scores(scope, container, est, operator.inputs[0], is_regressor=True)
@@ -57,7 +61,7 @@ def _transform_regressor(scope, operator, container, model):
 
 
 def _transform(scope, operator, container, model):
-    merged_prob_tensor = scope.get_unique_variable_name("merged_probability_tensor")
+    merged_prob_tensor = scope.get_unique_variable_name("stack_tra_merged_prob_tensor")
 
     predictions = [
         _fetch_scores(
@@ -77,18 +81,18 @@ def _transform(scope, operator, container, model):
         for est_idx in range(0, len(op.estimators_))
     )
     if select_lact_column:
-        column_index_name = scope.get_unique_variable_name("column_index")
+        column_index_name = scope.get_unique_variable_name("stack_tr_column_index")
         container.add_initializer(
             column_index_name, onnx_proto.TensorProto.INT64, [], [1]
         )
         new_predictions = []
         for ipred, pred in enumerate(predictions):
-            prob1 = scope.get_unique_variable_name("stack_prob%d" % ipred)
+            prob1 = scope.get_unique_variable_name("stack_tr_prob%d" % ipred)
             container.add_node(
                 "ArrayFeatureExtractor",
                 [pred, column_index_name],
                 prob1,
-                name=scope.get_unique_operator_name("ArrayFeatureExtractor"),
+                name=scope.get_unique_operator_name("SV_AFE"),
                 op_domain="ai.onnx.ml",
             )
             new_predictions.append(prob1)
@@ -124,18 +128,16 @@ def convert_sklearn_stacking_classifier(
     else:
         classes = np.array([s.encode("utf-8") for s in classes])
 
-    classes_name = scope.get_unique_variable_name("classes")
-    argmax_output_name = scope.get_unique_variable_name("argmax_output")
-    reshaped_result_name = scope.get_unique_variable_name("reshaped_result")
-    array_feature_extractor_result_name = scope.get_unique_variable_name(
-        "array_feature_extractor_result"
-    )
+    classes_name = scope.get_unique_variable_name("stack_classes")
+    argmax_output_name = scope.get_unique_variable_name("stack_argmax_output")
+    reshaped_result_name = scope.get_unique_variable_name("stack_reshaped_result")
+    array_feature_extractor_result_name = scope.get_unique_variable_name("stack_result")
 
     container.add_initializer(classes_name, class_type, classes.shape, classes)
 
     merged_proba_tensor = _transform(scope, operator, container, stacking_op)
     merge_proba = scope.declare_local_variable(
-        "merged_stacked_proba", operator.inputs[0].type.__class__()
+        "stack_merge_proba", operator.inputs[0].type.__class__()
     )
     container.add_node("Identity", [merged_proba_tensor], [merge_proba.onnx_name])
     prob = _fetch_scores(
@@ -149,13 +151,13 @@ def convert_sklearn_stacking_classifier(
         "Identity",
         prob,
         operator.outputs[1].onnx_name,
-        name=scope.get_unique_operator_name("OpProb"),
+        name=scope.get_unique_operator_name("Identity_OpProb"),
     )
     container.add_node(
         "ArgMax",
         prob,
         argmax_output_name,
-        name=scope.get_unique_operator_name("ArgMax"),
+        name=scope.get_unique_operator_name("AFE_ArgMax"),
         axis=1,
     )
     container.add_node(
@@ -163,7 +165,7 @@ def convert_sklearn_stacking_classifier(
         [classes_name, argmax_output_name],
         array_feature_extractor_result_name,
         op_domain="ai.onnx.ml",
-        name=scope.get_unique_operator_name("ArrayFeatureExtractor"),
+        name=scope.get_unique_operator_name("SC_AFE_Stack"),
     )
 
     if class_type == onnx_proto.TensorProto.INT32:
diff --git a/tests/test_issues_2024.py b/tests/test_issues_2024.py
index 364ef8b78..569e6a027 100644
--- a/tests/test_issues_2024.py
+++ b/tests/test_issues_2024.py
@@ -1,6 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 import unittest
 import packaging.version as pv
+from sklearn.utils._testing import ignore_warnings
+from sklearn.exceptions import ConvergenceWarning
 from onnxruntime import __version__ as ort_version
 
 
@@ -45,6 +47,7 @@ def test_issue_1053(self):
         pv.Version(ort_version) < pv.Version("1.16.0"),
         reason="opset 19 not implemented",
     )
+    @ignore_warnings(category=(ConvergenceWarning,))
    def test_issue_1055(self):
         import numpy as np
         from numpy.testing import assert_almost_equal
@@ -111,6 +114,157 @@ def test_issue_1055(self):
         got = sess.run(None, {"input": np.array(corpus).reshape((-1, 1))})
         assert_almost_equal(expected, got[1], decimal=2)
 
+    @unittest.skipIf(
+        pv.Version(ort_version) < pv.Version("1.16.0"),
+        reason="opset 19 not implemented",
+    )
+    def test_issue_1069(self):
+        import math
+        from typing import Any
+        import numpy
+        import pandas
+        from sklearn import (
+            base,
+            compose,
+            ensemble,
+            linear_model,
+            pipeline,
+            preprocessing,
+            datasets,
+        )
+        from sklearn.model_selection import train_test_split
+        from sklearn.tree import DecisionTreeClassifier
+        import onnxruntime
+        from skl2onnx import to_onnx
+        from skl2onnx.sklapi import CastTransformer
+
+        class FLAGS:
+            classes = 7
+            samples = 1000
+            timesteps = 5
+            trajectories = int(1000 / 5)
+            features = 10
+            seed = 10
+
+        columns = [
+            f"facelandmark{i}" for i in range(1, int(FLAGS.features / 2) + 1)
+        ] + [f"poselandmark{i}" for i in range(1, int(FLAGS.features / 2) + 1)]
+
+        X, y = datasets.make_classification(
+            n_classes=FLAGS.classes,
+            n_informative=math.ceil(math.log2(FLAGS.classes * 2)),
+            n_samples=FLAGS.samples,
+            n_features=FLAGS.features,
+            random_state=FLAGS.seed,
+        )
+
+        X = pandas.DataFrame(X, columns=columns)
+
+        X["trajectory"] = numpy.repeat(
+            numpy.arange(FLAGS.trajectories), FLAGS.timesteps
+        )
+        X["timestep"] = numpy.tile(numpy.arange(FLAGS.timesteps), FLAGS.trajectories)
+
+        trajectory_train, trajectory_test = train_test_split(
+            X["trajectory"].unique(),
+            test_size=0.25,
+            random_state=FLAGS.seed,
+        )
+
+        trajectory_train, trajectory_test = set(trajectory_train), set(trajectory_test)
+
+        X_train, X_test = (
+            X[X["trajectory"].isin(trajectory_train)],
+            X[X["trajectory"].isin(trajectory_test)],
+        )
+        y_train, _ = y[X_train.index], y[X_test.index]
+
+        def augment_with_lag_timesteps(X, k, columns):
+            augmented = X.copy()
+
+            for i in range(1, k + 1):
+                shifted = X[columns].groupby(X["trajectory"]).shift(i)
+                shifted.columns = [f"{x}_lag{i}" for x in shifted.columns]
+
+                augmented = pandas.concat([augmented, shifted], axis=1)
+
+            return augmented
+
+        X_train = augment_with_lag_timesteps(X_train, k=3, columns=X.columns[:-2])
+        X_test = augment_with_lag_timesteps(X_test, k=3, columns=X.columns[:-2])
+
+        X_train.drop(columns=["trajectory", "timestep"], inplace=True)
+        X_test.drop(columns=["trajectory", "timestep"], inplace=True)
+
+        def abc_Embedder() -> list[tuple[str, Any]]:
+            return [
+                ("cast64", CastTransformer(dtype=numpy.float64)),
+                ("scaler", preprocessing.StandardScaler()),
+                ("cast32", CastTransformer()),
+                ("basemodel", DecisionTreeClassifier(max_depth=2)),
+            ]
+
+        def Classifier(features: list[str]) -> base.BaseEstimator:
+            feats = [i for i, x in enumerate(features) if x.startswith("facelandmark")]
+
+            classifier = ensemble.StackingClassifier(
+                estimators=[
+                    (
+                        "facepipeline",
+                        pipeline.Pipeline(
+                            [
+                                (
+                                    "preprocessor",
+                                    compose.ColumnTransformer(
+                                        [("identity", "passthrough", feats)]
+                                    ),
+                                ),
+                                ("embedder", pipeline.Pipeline(steps=abc_Embedder())),
+                            ]
+                        ),
+                    ),
+                    (
+                        "posepipeline",
+                        pipeline.Pipeline(
+                            [
+                                (
+                                    "preprocessor",
+                                    compose.ColumnTransformer(
+                                        [("identity", "passthrough", feats)]
+                                    ),
+                                ),
+                                ("embedder", pipeline.Pipeline(steps=abc_Embedder())),
+                            ]
+                        ),
+                    ),
+                ],
+                final_estimator=linear_model.LogisticRegression(
+                    multi_class="multinomial"
+                ),
+            )
+
+            return classifier
+
+        model = Classifier(list(X_train.columns))
+        model.fit(X_train, y_train)
+
+        sample = X_train[:1].astype(numpy.float32)
+
+        for m in [model.estimators_[0].steps[0][-1], model.estimators_[0], model]:
+            with self.subTest(model=type(m)):
+                exported = to_onnx(
+                    model,
+                    X=numpy.asarray(sample),
+                    name="classifier",
+                    target_opset={"": 12, "ai.onnx.ml": 2},
+                    options={id(model): {"zipmap": False}},
+                )
+
+                modelengine = onnxruntime.InferenceSession(
+                    exported.SerializeToString(), providers=["CPUExecutionProvider"]
+                )
+                assert modelengine is not None
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
diff --git a/tests/test_sklearn_pipeline_concat_tfidf.py b/tests/test_sklearn_pipeline_concat_tfidf.py
index 8384dadcf..cda02b60f 100644
--- a/tests/test_sklearn_pipeline_concat_tfidf.py
+++ b/tests/test_sklearn_pipeline_concat_tfidf.py
@@ -320,6 +320,7 @@ def get_pipeline(N=10000):
     @unittest.skipIf(TARGET_OPSET < 11, reason="SequenceConstruct not available")
     @ignore_warnings(category=(DeprecationWarning, FutureWarning, UserWarning))
     @unittest.skipIf(not skl12(), reason="sparse_output")
+    @unittest.skipIf(TARGET_OPSET < 18, reason="too long")
     def test_issue_712_svc_binary(self):
         pipe, dfx_test = TestSklearnPipelineConcatTfIdf.get_pipeline()
         expected = pipe.transform(dfx_test)
diff --git a/tests/test_sklearn_random_forest_converters.py b/tests/test_sklearn_random_forest_converters.py
index d647dd166..9c6a60f6c 100644
--- a/tests/test_sklearn_random_forest_converters.py
+++ b/tests/test_sklearn_random_forest_converters.py
@@ -320,6 +320,7 @@ def test_random_forest_classifier_double(self):
         )
 
     @ignore_warnings(category=FutureWarning)
+    @unittest.skipIf(TARGET_OPSET < 18, reason="too long")
     def test_model_random_forest_classifier_multi_output_int(self):
         model, X_test = fit_multi_output_classification_model(
             RandomForestClassifier(random_state=42, n_estimators=20)
@@ -631,7 +632,7 @@ def test_boston_pca_rf(self):
             pipe,
             model_onnx,
             methods=["predict"],
-            basename="SklearnBostonPCARF-Dec4",
+            basename="SklearnBostonPCARF-Dec3",
         )
 
     @ignore_warnings(category=FutureWarning)
diff --git a/tests/test_sklearn_tfidf_vectorizer_converter_pipeline.py b/tests/test_sklearn_tfidf_vectorizer_converter_pipeline.py
index 5872d0f09..65b2fbf33 100644
--- a/tests/test_sklearn_tfidf_vectorizer_converter_pipeline.py
+++ b/tests/test_sklearn_tfidf_vectorizer_converter_pipeline.py
@@ -168,4 +168,4 @@ def test_model_tfidf_vectorizer_pipeline_stop_words(self):
 
 
 if __name__ == "__main__":
-    unittest.main()
+    unittest.main(verbosity=2)
diff --git a/tests/test_utils_sklearn.py b/tests/test_utils_sklearn.py
index c6c1d3847..4802c15d5 100644
--- a/tests/test_utils_sklearn.py
+++ b/tests/test_utils_sklearn.py
@@ -8,6 +8,7 @@
 import numpy
 import pandas
 from onnxruntime import __version__ as ort_version
+from sklearn.utils._testing import ignore_warnings
 from sklearn import __version__ as skl_version
 from sklearn.linear_model import LinearRegression
 from sklearn.ensemble import RandomForestRegressor
@@ -128,6 +129,7 @@ def test_pipeline_lr(self):
         pv.Version(ort_version) <= pv.Version("0.4.0"), reason="onnxruntime too old"
     )
     @unittest.skipIf(not skl12(), reason="sparse_output")
+    @ignore_warnings(category=(RuntimeWarning,))
     def test_pipeline_column_transformer(self):
         iris = load_iris()
         X = iris.data[:, :3]
@@ -264,4 +266,4 @@ def test_pipeline_column_transformer(self):
 
 
 if __name__ == "__main__":
-    unittest.main()
+    unittest.main(verbosity=2)
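
Note on the fix: before this change, converting a Pipeline whose step is itself
a Pipeline ending in a classifier could leave the inner ZipMap node in place,
so the outer graph received a sequence of maps instead of a probability tensor,
or the conversion could fail while sorting the graph. The snippet below is a
minimal sketch of the now-supported scenario; the data, variable names, and the
two-level pipeline are illustrative assumptions, not part of the patch.

    import numpy
    import onnxruntime
    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression
    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import StandardScaler
    from skl2onnx import to_onnx

    X, y = make_classification(n_samples=100, n_features=5, random_state=0)
    X = X.astype(numpy.float32)

    # The last step is itself a pipeline ending in a classifier: the inner
    # classifier must not emit a ZipMap node, otherwise downstream consumers
    # receive a sequence of maps instead of a tensor.
    inner = Pipeline([("scaler", StandardScaler()), ("clf", LogisticRegression())])
    outer = Pipeline([("embedded", inner)])
    outer.fit(X, y)

    onx = to_onnx(outer, X[:1], options={id(outer): {"zipmap": False}})
    sess = onnxruntime.InferenceSession(
        onx.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    print(sess.run(None, {"X": X[:5]}))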
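Because register_converter("SklearnPipeline", ...) now declares the classifier
options, an options dictionary may target the pipeline instance itself rather
than only the final estimator. A hedged usage sketch, reusing the hypothetical
`outer` pipeline from the previous snippet:

    # Options keyed by id(outer) are validated against the newly registered
    # set: zipmap, nocl, output_class_labels, raw_scores.
    onx2 = to_onnx(outer, X[:1], options={id(outer): {"zipmap": False, "nocl": True}})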