Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Investigate issue 1069 #1072

Merged
merged 9 commits into from
Apr 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion CHANGELOGS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@

## 1.17.0 (development)

* Fix the conversion of pipelines that contain nested pipelines,
issue [#1069](https://github.com/onnx/sklearn-onnx/pull/1069),
[#1072](https://github.com/onnx/sklearn-onnx/pull/1072)
* Fix unexpected type for intercept in PoissonRegressor and GammaRegressor
[#1070](https://github.com/onnx/sklearn-onnx/pull/1070)
* Add support for scikti-learn 1.4.0,
* Add support for scikit-learn 1.4.0,
[#1058](https://github.com/onnx/sklearn-onnx/pull/1058),
fixes issues [Many examples in the gallery are showing "broken"](https://github.com/onnx/sklearn-onnx/pull/1057),
[TFIDF vectorizer target_opset issue](https://github.com/onnx/sklearn-onnx/pull/1055),
Expand Down
8 changes: 7 additions & 1 deletion skl2onnx/_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,12 @@ def _parse_sklearn_pipeline(scope, model, inputs, custom_parsers=None):
:return: A list of output variables produced by the input pipeline
"""
for step in model.steps:
if (
scope.options is None
or (id(step[1]) not in scope.options and type(step[1]) not in scope.options)
) and is_classifier(step[1]):
options = scope.get_options(model, {"zipmap": True})
scope.add_options(id(step[1]), options)
inputs = _parse_sklearn(scope, step[1], inputs, custom_parsers=custom_parsers)
return inputs

Expand Down Expand Up @@ -543,7 +549,7 @@ def _parse_sklearn_classifier(scope, model, inputs, custom_parsers=None):

if options.get("output_class_labels", False):
raise RuntimeError(
"Option 'output_class_labels' is not compatible with option " "'zipmap'."
"Option 'output_class_labels' is not compatible with option 'zipmap'."
)

return _apply_zipmap(
Expand Down
82 changes: 78 additions & 4 deletions skl2onnx/common/_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -917,7 +917,7 @@ def ensure_topological_order(self):

def nstr(name):
if name in order:
return "%s#%d" % (name, order[name])
return "%s -- #%d" % (name, order[name])
return name

rows = [
Expand All @@ -937,17 +937,91 @@ def nstr(name):
"%s|%s(%s) -> [%s]"
% (
n.op_type,
n.name or n.op_type,
n.name or "",
", ".join(map(nstr, n.input)),
", ".join(n.output),
)
for n in self.nodes
)
rows.append("--")
rows.append("--full-marked-nodes-with-output--")
rows.append("--")
set_name = set(order)
for n in self.nodes:
if set(n.input) & set_name == set(n.input) and set(
n.output
) & set_name == set(n.output):
rows.append(
"%s|%s(%s) -> [%s]"
% (
n.op_type,
n.name or "",
", ".join(map(nstr, n.input)),
", ".join(n.output),
)
)
rows.append("--")
rows.append("--full-marked-nodes-only-input--")
rows.append("--")
set_name = set(order)
for n in self.nodes:
if set(n.input) & set_name == set(n.input) and set(
n.output
) & set_name != set(n.output):
rows.append(
"%s|%s(%s) -> [%s]"
% (
n.op_type,
n.name or "",
", ".join(map(nstr, n.input)),
", ".join(n.output),
)
)
rows.append("--")
rows.append("--not-fully-marked-nodes--")
rows.append("--")
set_name = set(order)
for n in self.nodes:
if (set(n.input) & set_name) and set(n.input) & set_name != set(
n.input
):
rows.append(
"%s|%s(%s) -> [%s]"
% (
n.op_type,
n.name or "",
", ".join(map(nstr, n.input)),
", ".join(n.output),
)
)
rows.append("--")
rows.append("--unmarked-nodes--")
rows.append("--")
for n in self.nodes:
if not (set(n.input) & set_name):
rows.append(
"%s|%s(%s) -> [%s]"
% (
n.op_type,
n.name or "",
", ".join(map(nstr, n.input)),
", ".join(n.output),
)
)

raise RuntimeError(
"After %d iterations for %d nodes, still unable "
"to sort names %r. The graph may be disconnected. "
"List of operators: %s"
% (n_iter, len(self.nodes), missing_names, "\n".join(rows))
"List of operators:\n---------%s\n-------"
"\n-- ORDER--\n%s\n-- MISSING --\n%s"
% (
n_iter,
len(self.nodes),
missing_names,
"\n".join(rows),
pprint.pformat(list(sorted([(v, k) for k, v in order.items()]))),
pprint.pformat(set(order.values()) & set(missing_names)),
)
)

# Update order
Expand Down
19 changes: 8 additions & 11 deletions skl2onnx/operator_converters/array_feature_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,14 @@ def convert_sklearn_array_feature_extractor(
column_indices_name = scope.get_unique_variable_name("column_indices")

for i, ind in enumerate(operator.column_indices):
if not isinstance(ind, int):
raise RuntimeError(
(
"Column {0}:'{1}' indices must be specified "
"as integers. This error may happen when "
"column names are used to define a "
"ColumnTransformer. Column name in input data "
"do not necessarily match input variables "
"defined for the ONNX model."
).format(i, ind)
)
assert isinstance(ind, int), (
"Column {0}:'{1}' indices must be specified "
"as integers. This error may happen when "
"column names are used to define a "
"ColumnTransformer. Column name in input data "
"do not necessarily match input variables "
"defined for the ONNX model."
).format(i, ind)
container.add_initializer(
column_indices_name,
onnx_proto.TensorProto.INT64,
Expand Down
18 changes: 14 additions & 4 deletions skl2onnx/operator_converters/pipelines.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0

from sklearn.base import is_classifier
from sklearn.pipeline import Pipeline
from ..common._registration import register_converter
from ..common._topology import Scope, Operator
from ..common._container import ModelComponentContainer
Expand All @@ -16,7 +17,7 @@ def convert_pipeline(
inputs = operator.inputs
for step in model.steps:
step_model = step[1]
if is_classifier(step_model):
if is_classifier(step_model) or isinstance(step_model, Pipeline):
scope.add_options(id(step_model), options={"zipmap": False})
container.add_options(id(step_model), options={"zipmap": False})
outputs = _parse_sklearn(scope, step_model, inputs, custom_parsers=None)
Expand Down Expand Up @@ -52,18 +53,27 @@ def convert_feature_union(
scope: Scope, operator: Operator, container: ModelComponentContainer
):
raise NotImplementedError(
"This converter not needed so far. It is usually handled " "during parsing."
"This converter not needed so far. It is usually handled during parsing."
)


def convert_column_transformer(
scope: Scope, operator: Operator, container: ModelComponentContainer
):
raise NotImplementedError(
"This converter not needed so far. It is usually handled " "during parsing."
"This converter not needed so far. It is usually handled during parsing."
)


register_converter("SklearnPipeline", convert_pipeline)
register_converter(
"SklearnPipeline",
convert_pipeline,
options={
"zipmap": [True, False, "columns"],
"nocl": [True, False],
"output_class_labels": [False, True],
"raw_scores": [True, False],
},
)
register_converter("SklearnFeatureUnion", convert_feature_union)
register_converter("SklearnColumnTransformer", convert_column_transformer)
38 changes: 20 additions & 18 deletions skl2onnx/operator_converters/stacking.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,15 @@ def _fetch_scores(
container.add_options(id(model), {"raw_scores": raw_scores})
this_operator.inputs.append(inputs)
if is_regressor:
output_proba = scope.declare_local_variable("variable", inputs.type.__class__())
output_proba = scope.declare_local_variable(
"sv_variable", inputs.type.__class__()
)
this_operator.outputs.append(output_proba)
else:
label_name = scope.declare_local_variable("label", Int64TensorType())
label_name = scope.declare_local_variable("sv_label", Int64TensorType())
this_operator.outputs.append(label_name)
output_proba = scope.declare_local_variable(
"probability_tensor", inputs.type.__class__()
"sv_prob_tensor", inputs.type.__class__()
)
this_operator.outputs.append(output_proba)

Expand All @@ -43,7 +45,9 @@ def _add_passthrough_connection(operator, predictions):


def _transform_regressor(scope, operator, container, model):
merged_prob_tensor = scope.get_unique_variable_name("merged_probability_tensor")
merged_prob_tensor = scope.get_unique_variable_name(
"stack_tr_reg_merged_probability_tensor"
)

predictions = [
_fetch_scores(scope, container, est, operator.inputs[0], is_regressor=True)
Expand All @@ -57,7 +61,7 @@ def _transform_regressor(scope, operator, container, model):


def _transform(scope, operator, container, model):
merged_prob_tensor = scope.get_unique_variable_name("merged_probability_tensor")
merged_prob_tensor = scope.get_unique_variable_name("stack_tra_merged_prob_tensor")

predictions = [
_fetch_scores(
Expand All @@ -77,18 +81,18 @@ def _transform(scope, operator, container, model):
for est_idx in range(0, len(op.estimators_))
)
if select_lact_column:
column_index_name = scope.get_unique_variable_name("column_index")
column_index_name = scope.get_unique_variable_name("stack_tr_column_index")
container.add_initializer(
column_index_name, onnx_proto.TensorProto.INT64, [], [1]
)
new_predictions = []
for ipred, pred in enumerate(predictions):
prob1 = scope.get_unique_variable_name("stack_prob%d" % ipred)
prob1 = scope.get_unique_variable_name("stack_tr_prob%d" % ipred)
container.add_node(
"ArrayFeatureExtractor",
[pred, column_index_name],
prob1,
name=scope.get_unique_operator_name("ArrayFeatureExtractor"),
name=scope.get_unique_operator_name("SV_AFE"),
op_domain="ai.onnx.ml",
)
new_predictions.append(prob1)
Expand Down Expand Up @@ -124,18 +128,16 @@ def convert_sklearn_stacking_classifier(
else:
classes = np.array([s.encode("utf-8") for s in classes])

classes_name = scope.get_unique_variable_name("classes")
argmax_output_name = scope.get_unique_variable_name("argmax_output")
reshaped_result_name = scope.get_unique_variable_name("reshaped_result")
array_feature_extractor_result_name = scope.get_unique_variable_name(
"array_feature_extractor_result"
)
classes_name = scope.get_unique_variable_name("stack_classes")
argmax_output_name = scope.get_unique_variable_name("stack_argmax_output")
reshaped_result_name = scope.get_unique_variable_name("stack_reshaped_result")
array_feature_extractor_result_name = scope.get_unique_variable_name("stack_result")

container.add_initializer(classes_name, class_type, classes.shape, classes)

merged_proba_tensor = _transform(scope, operator, container, stacking_op)
merge_proba = scope.declare_local_variable(
"merged_stacked_proba", operator.inputs[0].type.__class__()
"stack_merge_proba", operator.inputs[0].type.__class__()
)
container.add_node("Identity", [merged_proba_tensor], [merge_proba.onnx_name])
prob = _fetch_scores(
Expand All @@ -149,21 +151,21 @@ def convert_sklearn_stacking_classifier(
"Identity",
prob,
operator.outputs[1].onnx_name,
name=scope.get_unique_operator_name("OpProb"),
name=scope.get_unique_operator_name("Identity_OpProb"),
)
container.add_node(
"ArgMax",
prob,
argmax_output_name,
name=scope.get_unique_operator_name("ArgMax"),
name=scope.get_unique_operator_name("AFE_ArgMax"),
axis=1,
)
container.add_node(
"ArrayFeatureExtractor",
[classes_name, argmax_output_name],
array_feature_extractor_result_name,
op_domain="ai.onnx.ml",
name=scope.get_unique_operator_name("ArrayFeatureExtractor"),
name=scope.get_unique_operator_name("SC_AFE_Stack"),
)

if class_type == onnx_proto.TensorProto.INT32:
Expand Down
Loading
Loading