From 28bbf3a21039116aa33465e977d1fadf46201e64 Mon Sep 17 00:00:00 2001 From: Zachary Tessler Date: Fri, 16 Dec 2022 14:55:35 -0500 Subject: [PATCH 1/2] parse featureunion with nested featureunion or columntransformer Signed-off-by: Zachary Tessler --- skl2onnx/_parse.py | 2 +- tests/test_sklearn_feature_union.py | 24 ++++++++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/skl2onnx/_parse.py b/skl2onnx/_parse.py index b64430ebe..2d07e9588 100644 --- a/skl2onnx/_parse.py +++ b/skl2onnx/_parse.py @@ -287,7 +287,7 @@ def _parse_sklearn_feature_union(scope, model, inputs, custom_parsers=None): # Encode each transform as our IR object for name, transform in model.transformer_list: transformed_result_names.append( - _parse_sklearn_simple_model( + _parse_sklearn( scope, transform, inputs, custom_parsers=custom_parsers)[0]) if (model.transformer_weights is not None and name in diff --git a/tests/test_sklearn_feature_union.py b/tests/test_sklearn_feature_union.py index f183eb7b3..46926e45e 100644 --- a/tests/test_sklearn_feature_union.py +++ b/tests/test_sklearn_feature_union.py @@ -38,6 +38,30 @@ def test_feature_union_default(self): dump_data_and_model(X_test, model, model_onnx, basename="SklearnFeatureUnionDefault") + @unittest.skipIf( + pv.Version(ort_version) <= pv.Version('0.4.0'), + reason="onnxruntime too old") + def test_feature_union_nested(self): + data = load_iris() + X, y = data.data, data.target + X = X.astype(np.float32) + X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, + random_state=42) + model = FeatureUnion([ + ('features', FeatureUnion([ + ('standard', StandardScaler()), + ]) + ), + ]).fit(X_train) + + model_onnx = convert_sklearn( + model, 'feature union', + [('input', FloatTensorType([None, X_test.shape[1]]))], + target_opset=TARGET_OPSET) + self.assertTrue(model_onnx is not None) + dump_data_and_model(X_test, model, model_onnx, + basename="SklearnFeatureUnionNested") + @unittest.skipIf( pv.Version(ort_version) <= pv.Version('0.4.0'), reason="onnxruntime too old") From f1774815aa8da42e6bcbc2ac4ca6a2b85bd7b0cb Mon Sep 17 00:00:00 2001 From: Zachary Tessler Date: Fri, 16 Dec 2022 16:01:56 -0500 Subject: [PATCH 2/2] flake8 line length Signed-off-by: Zachary Tessler --- tests/test_sklearn_feature_union.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_sklearn_feature_union.py b/tests/test_sklearn_feature_union.py index 46926e45e..b40774ad1 100644 --- a/tests/test_sklearn_feature_union.py +++ b/tests/test_sklearn_feature_union.py @@ -45,8 +45,8 @@ def test_feature_union_nested(self): data = load_iris() X, y = data.data, data.target X = X.astype(np.float32) - X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, - random_state=42) + X_train, X_test, y_train, y_test = train_test_split( + X, y, test_size=0.5, random_state=42) model = FeatureUnion([ ('features', FeatureUnion([ ('standard', StandardScaler()),