Skip to content

Commit

Permalink
parse featureunion with nested featureunion or columntransformer
Browse files Browse the repository at this point in the history
  • Loading branch information
ztessler committed Dec 16, 2022
1 parent 4da35da commit 41a7a62
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 1 deletion.
2 changes: 1 addition & 1 deletion skl2onnx/_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def _parse_sklearn_feature_union(scope, model, inputs, custom_parsers=None):
# Encode each transform as our IR object
for name, transform in model.transformer_list:
transformed_result_names.append(
_parse_sklearn_simple_model(
_parse_sklearn(
scope, transform, inputs,
custom_parsers=custom_parsers)[0])
if (model.transformer_weights is not None and name in
Expand Down
24 changes: 24 additions & 0 deletions tests/test_sklearn_feature_union.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,30 @@ def test_feature_union_default(self):
dump_data_and_model(X_test, model, model_onnx,
basename="SklearnFeatureUnionDefault")

@unittest.skipIf(
pv.Version(ort_version) <= pv.Version('0.4.0'),
reason="onnxruntime too old")
def test_feature_union_nested(self):
data = load_iris()
X, y = data.data, data.target
X = X.astype(np.float32)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5,
random_state=42)
model = FeatureUnion([
('features', FeatureUnion([
('standard', StandardScaler()),
])
),
]).fit(X_train)

model_onnx = convert_sklearn(
model, 'feature union',
[('input', FloatTensorType([None, X_test.shape[1]]))],
target_opset=TARGET_OPSET)
self.assertTrue(model_onnx is not None)
dump_data_and_model(X_test, model, model_onnx,
basename="SklearnFeatureUnionNested")

@unittest.skipIf(
pv.Version(ort_version) <= pv.Version('0.4.0'),
reason="onnxruntime too old")
Expand Down

0 comments on commit 41a7a62

Please sign in to comment.