diff --git a/skl2onnx/algebra/graph_state.py b/skl2onnx/algebra/graph_state.py
index 95969b4ca..befd5ac14 100644
--- a/skl2onnx/algebra/graph_state.py
+++ b/skl2onnx/algebra/graph_state.py
@@ -22,10 +22,12 @@ class GraphState:
     def __init__(self, inputs, output_names, operator_name, scope,
                  container, converter, onnx_prefix_name=None,
                  options=None, expected_inputs=None,
-                 expected_outputs=None, operator=None, **attrs):
+                 expected_outputs=None, operator=None,
+                 run_converter=False, **attrs):
         self.inputs = inputs
         self._output_names = output_names
         self.scope = scope
+        self.run_converter = run_converter
         self.operator = operator
         if hasattr(operator_name, 'fit'):
             from .. import get_model_alias
@@ -72,14 +74,19 @@ def __init__(self, inputs, output_names, operator_name, scope,
                 change = []
                 for vi in v:
                     change.append((vi, None) if isinstance(vi, str) else vi)
+
         if self._output_names is not None:
             res = []
-            for i in range(0, len(self._expected_outputs)):
-                if i < len(self._output_names):
-                    res.append(
-                        (self._output_names[i], self._expected_outputs[i][1]))
-                else:
-                    res.append(self._expected_outputs[i])
+            if self._expected_outputs is not None:
+                for i in range(0, len(self._expected_outputs)):
+                    if i < len(self._output_names):
+                        res.append(
+                            (self._output_names[i],
+                             self._expected_outputs[i][1]))
+                    else:
+                        res.append(self._expected_outputs[i])
+            for i in range(len(res), len(self._output_names)):
+                res.append((self._output_names[i], None))
             self._expected_outputs = res
 
         if self._expected_outputs is not None:
@@ -161,6 +168,9 @@ def __fct__(var, operator):
                 raise RuntimeError(
                     "Mismatch number of outputs between %s and %s." % (
                         v, self._output_names[index]))
+            v2 = self.scope.get(var[0], None)
+            if v2 is not None:
+                v = [v2]
             if v[0][0] != self._output_names[index]:
                 raise RuntimeError(
                     "Mismatch output name %r between %s and %s." % (
@@ -375,44 +385,71 @@ def run(self):
                     "Attribute 'sub_op_' is not empty.")
 
             # a model is converted into a subgraph
-            sub_op = self.scope.declare_local_operator(
-                self.operator_name, self.operator_instance)
-            sub_op.inputs = self.computed_inputs_
+            sub_op_inputs = self.computed_inputs_
 
             # output are not defined, we need to call a parser.
             from .._parse import _parse_sklearn
             self.scope.add_options(
-                id(sub_op.raw_operator), self.options)
+                id(self.operator_instance), self.options)
             sub_outputs = _parse_sklearn(
-                self.scope, self.operator_instance, sub_op.inputs)
+                self.scope, self.operator_instance, sub_op_inputs)
+            set_input_names = set(v.onnx_name for v in sub_op_inputs)
+            sub_op = None
+            for op in self.scope.operators.values():
+                for inp in op.inputs:
+                    if inp.onnx_name in set_input_names:
+                        sub_op = op
             if (sub_outputs is None or
                     None in sub_outputs):
                 raise RuntimeError(
                     "Wrong result when parsing model {}.".format(
                         type(self.operator_instance)))
 
-            self.computed_outputs_ = []
+            # Checks operator outputs
            for out in sub_outputs:
-                self.computed_outputs_.append(
-                    out if isinstance(out, Variable)
-                    else Variable(
-                        v.raw_name,
-                        self.scope.get_unique_variable_name(v.raw_name),
-                        self.scope, v.type))
-
-            sub_op.outputs = self.computed_outputs_
+                if not isinstance(out, Variable):
+                    raise TypeError(
+                        "Output %s must be of type Variable." % out)
+            self.sub_op_ = sub_op
+            sub_op.outputs = sub_outputs
+
             shape_calc = get_shape_calculator(self.operator_name)
             shape_calc(sub_op)
+
+            # Add Identity nodes to be consistent with `is_fed`
+            # in Topology.
+            if expected_outputs is not None:
+                outputs = [
+                    self._get_output_name(
+                        self._output_names, o, self.scope)
+                    for o in expected_outputs]
+            else:
+                outputs = [
+                    self.scope.declare_local_variable(
+                        o.onnx_name, type=o.type)
+                    for o in sub_op.outputs]
+            if len(outputs) != len(sub_op.outputs):
+                raise RuntimeError(
+                    "Mismatched number of outputs %s and %s." % (
+                        outputs, sub_op.outputs))
+
+            output_names = [i[0] for i in outputs]
+            for i, out in enumerate(sub_op.outputs):
+                var = outputs[i]
+                self.container.add_node(
+                    'Identity', [out.onnx_name], [var[0]],
+                    name=self.scope.get_unique_operator_name("SubOpId"))
+            self.computed_outputs_ = outputs
             self.computed_inputs2_ = sub_op.inputs
-            self.computed_inputs2_ = [
+            self.computed_outputs2_ = [
                 (v.raw_name, v.type) for v in self.computed_outputs_]
-            self.sub_op_ = sub_op
-            self.computed_outputs_ = sub_op.outputs
 
-            # The parser was run on sub-operators and neither the shape
-            # calcutor nor the converter.
-            conv = get_converter(self.operator_name)
-            conv(self.scope, sub_op, self.container)
+            if self.run_converter:
+                # The parser was run on sub-operators but neither the shape
+                # calculator nor the converter.
+                conv = get_converter(self.operator_name)
+                conv(self.scope, sub_op, self.container)
+                # sub_op.is_evaluated = True
         else:
             # only one node is added
             if self.options is not None:
diff --git a/skl2onnx/algebra/onnx_operator.py b/skl2onnx/algebra/onnx_operator.py
index e78b1b9bd..93e6fd697 100644
--- a/skl2onnx/algebra/onnx_operator.py
+++ b/skl2onnx/algebra/onnx_operator.py
@@ -5,7 +5,6 @@
 # --------------------------------------------------------------------------
 import numpy as np
 from scipy.sparse import coo_matrix
-from onnxconverter_common.onnx_ops import apply_identity
 from ..proto import TensorProto
 from ..common.data_types import (
     _guess_type_proto_str, _guess_type_proto_str_inv)
@@ -113,61 +112,6 @@ def get_output_type_inference(self, input_shapes=None):
         return outputs[self.index:self.index + 1]
 
 
-class OnnxSubOperator:
-    """
-    Includes a sub operator in the ONNX graph.
-    """
-
-    def __init__(self, op, inputs, output_names=None, op_version=None,
-                 options=None):
-        self.op = op
-        self.output_names = output_names
-        if not isinstance(inputs, list):
-            inputs = [inputs]
-        self.inputs = inputs
-        self.op_version = op_version
-        self.options = options
-
-    def add_to(self, scope, container, operator=None):
-        """
-        Adds outputs to the container if not already added,
-        registered the outputs if the node is not final.
-
-        :param scope: scope
-        :param container: container
-        :param operator: overwrite inputs
-        """
-        if operator is not None:
-            raise RuntimeError(
-                "operator must be None, the operator to convert "
-                "is specified in member 'op'.")
-        try:
-            op_type = sklearn_operator_name_map[type(self.op)]
-        except KeyError:
-            raise RuntimeError(
-                "Unable to find a converter for model of type '{}'."
-                "".format(self.op.__class__.__name__))
-
-        this_operator = scope.declare_local_operator(op_type, self.op)
-        this_operator.inputs = self.inputs
-        if self.output_names is None:
-            output = scope.declare_local_variable('sub_%s' % op_type)
-            this_operator.outputs.append(output)
-            self.outputs = [output]
-        else:
-            self.outputs = []
-            for v in self.output_names:
-                if isinstance(v, Variable):
-                    output = scope.declare_local_variable(
-                        '%s_%s' % (v.onnx_name, op_type))
-                    apply_identity(
-                        scope, output.onnx_name, v.onnx_name, container)
-                elif isinstance(v, str):
-                    output = scope.declare_local_variable(v)
-                self.outputs.append(output)
-        this_operator.outputs.extend(self.outputs)
-
-
 class OnnxOperator:
     """
     Ancestor to every *ONNX* operator exposed in
@@ -278,6 +222,8 @@ def __init__(self, *inputs, op_version=None, output_names=None,
                 "The class cannot infer the number of variables "
                 "for node '{}' yet. output_names must be specified"
                 ".".format(self.__class__.__name__))
+        if isinstance(output_names, str):
+            output_names = [output_names]
 
         if op_version is None:
             if domain == '':
@@ -340,7 +286,7 @@ def __init__(self, *inputs, op_version=None, output_names=None,
                 if isinstance(inp, str):
                     self.inputs.append(OnnxOperator.UnscopedVariable(inp))
                 elif isinstance(inp, (OnnxOperator, Variable,
-                                      OnnxOperatorItem, OnnxSubOperator)):
+                                      OnnxOperatorItem, OnnxSubEstimator)):
                     self.inputs.append(inp)
                 elif isinstance(inp, tuple) and len(inp) == 2:
                     self.inputs.append(inp)
@@ -401,8 +347,11 @@ def __init__(self, *inputs, op_version=None, output_names=None,
             if all(map(lambda x: x is None, self.output_variables)):
                 self.output_variables = None
 
-        if (self.output_names is not None and
-                len(self.output_names) > len(self.expected_outputs)):
+        if (self.output_names is not None and (
+                self.expected_outputs is None or
+                len(self.output_names) > len(self.expected_outputs))):
+            if self.expected_outputs is None:
+                self.expected_outputs = []
             for i in range(len(self.expected_outputs),
                            len(self.output_names)):
                 self.expected_outputs.append((self.output_names[i], None))
@@ -986,6 +935,7 @@ def add_to(self, scope, container, operator=None, recursive=False):
                         input[0], input[0], scope=scope, type=input[1]))
                 else:
                     inputs.append(input)
+
             self.state = GraphState(
                 inputs, self.output_names_, self.operator_instance,
                 scope, container, None, op_version=self.op_version,
diff --git a/skl2onnx/common/_topology.py b/skl2onnx/common/_topology.py
index b8557b672..4c18cb089 100644
--- a/skl2onnx/common/_topology.py
+++ b/skl2onnx/common/_topology.py
@@ -6,6 +6,7 @@
 import re
 import warnings
+import pprint
 import numpy as np
 from onnx import onnx_pb as onnx_proto
 from onnxconverter_common.data_types import (  # noqa
@@ -224,6 +225,14 @@ def __init__(self, onnx_name, scope, type, raw_operator,
         self.target_opset = target_opset
         self.scope_inst = scope_inst
 
+    def __repr__(self):
+        return ("Operator(type='{0}', onnx_name='{1}', inputs='{2}', "
+                "outputs='{3}', raw_operator={4})".format(
+                    self.type, self.onnx_name,
+                    ','.join(v.onnx_name for v in self.inputs),
+                    ','.join(v.onnx_name for v in self.outputs),
+                    self.raw_operator))
+
     @property
     def full_name(self):
         """
@@ -329,6 +338,10 @@ def __init__(self, name, parent_scopes=None, variable_name_set=None,
         # Reserved variables.
         self.reserved = {}
 
+    def get(self, var_name, default_value):
+        "Returns the variable 'var_name', or 'default_value' if not found."
+        return self.variables.get(var_name, default_value)
+
     def temp(self):
         """
         Creates a new Scope with the same options but no names.
@@ -668,6 +681,7 @@ def topological_operator_iterator(self):
                 raise TypeError(
                     "operator.inputs must be a list not {}".format(
                         type(operator.inputs)))
+
             if (all(variable.is_fed for variable in operator.inputs)
                     and not operator.is_evaluated):
                 # Check if over-writing problem occurs (i.e., multiple
@@ -676,18 +690,34 @@
                     # Throw an error if this variable has been treated as
                     # an output somewhere
                     if variable.is_fed:
+                        add = ["", "--DEBUG-INFO--"]
+                        add.append("self.variable_name_set=%s" % (
+                            pprint.pformat(self.variable_name_set)))
+                        add.append("self.operator_name_set=%s" % (
+                            pprint.pformat(self.operator_name_set)))
+                        for scope in self.scopes:
+                            add.append(pprint.pformat(
+                                scope.variable_name_mapping))
+                            for var in scope.variables.values():
+                                add.append(" is_fed=%s %s" % (
+                                    getattr(var, 'is_fed', '?'), var))
+                            for op in scope.operators.values():
+                                add.append(" is_evaluated=%s %s" % (
+                                    getattr(op, 'is_evaluated', '?'), op))
                         raise RuntimeError(
                             "A variable is already assigned ({}) "
                             "for operator '{}' (name='{}'). This "
                             "may still happen if a converter is a "
-                            "combination of sub-operators and one of "
+                            "combination of sub-estimators and one "
                             "of them is producing this output. "
                             "In that case, an identity node must be "
-                            "added.".format(
+                            "added.{}".format(
                                 variable, operator.type,
-                                operator.onnx_name))
+                                operator.onnx_name,
+                                "\n".join(add)))
                     # Mark this variable as filled
                     variable.is_fed = True
 
+                # Mark this operator as handled
                 operator.is_evaluated = True
                 is_evaluation_happened = True
diff --git a/tests/test_algebra_onnx_operators_wrapped.py b/tests/test_algebra_onnx_operators_wrapped.py
index 722fda3d4..143d01247 100644
--- a/tests/test_algebra_onnx_operators_wrapped.py
+++ b/tests/test_algebra_onnx_operators_wrapped.py
@@ -8,7 +8,7 @@
 from onnxruntime import InferenceSession
 from skl2onnx import to_onnx
 from skl2onnx.algebra.onnx_ops import OnnxIdentity
-from skl2onnx.algebra.onnx_operator import OnnxSubOperator
+from skl2onnx.algebra.onnx_operator import OnnxSubEstimator as SubOp
 from skl2onnx import update_registered_converter
 from onnxruntime import __version__ as ortv
 from test_utils import TARGET_OPSET
@@ -59,7 +59,7 @@ def decorrelate_transformer_convertor(scope, operator, container):
     opv = container.target_opset
     out = operator.outputs
     X = operator.inputs[0]
-    subop = OnnxSubOperator(op.pca_, X, op_version=opv)
+    subop = SubOp(op.pca_, X, op_version=opv)
     Y = OnnxIdentity(subop, op_version=opv, output_names=out[:1])
     Y.add_to(scope, container)
 
@@ -69,8 +69,7 @@ def decorrelate_transformer_convertor2(scope, operator, container):
     opv = container.target_opset
     out = operator.outputs
     X = operator.inputs[0]
-    Y = OnnxSubOperator(op.pca_, X, op_version=opv,
-                        output_names=out[:1])
+    Y = SubOp(op.pca_, X, op_version=opv, output_names=out[:1])
     Y.add_to(scope, container)
 
 
@@ -91,9 +90,8 @@ def test_sub(self):
             decorrelate_transformer_convertor)
         onx = to_onnx(dec, X.astype(np.float32),
                       target_opset=TARGET_OPSET)
-
+        self.assertIn('output: "variable"', str(onx))
         sess = InferenceSession(onx.SerializeToString())
-
         exp = dec.transform(X.astype(np.float32))
         got = sess.run(None, {'X': X.astype(np.float32)})[0]
         assert_almost_equal(got, exp, decimal=4)
@@ -113,9 +111,8 @@ def test_sub_double(self):
             decorrelate_transformer_convertor)
         onx = to_onnx(dec, X.astype(np.float64),
                       target_opset=TARGET_OPSET)
-
+        self.assertIn('output: "variable"', str(onx))
         sess = InferenceSession(onx.SerializeToString())
-
         exp = dec.transform(X.astype(np.float64))
         got = sess.run(None, {'X': X.astype(np.float64)})[0]
         assert_almost_equal(got, exp, decimal=4)
@@ -135,9 +132,8 @@ def test_sub_output(self):
             decorrelate_transformer_convertor2)
         onx = to_onnx(dec, X.astype(np.float32),
                       target_opset=TARGET_OPSET)
-
+        self.assertIn('output: "variable"', str(onx))
         sess = InferenceSession(onx.SerializeToString())
-
         exp = dec.transform(X.astype(np.float32))
         got = sess.run(None, {'X': X.astype(np.float32)})[0]
         assert_almost_equal(got, exp, decimal=4)
@@ -157,9 +153,8 @@ def test_sub_output_double(self):
             decorrelate_transformer_convertor2)
         onx = to_onnx(dec, X.astype(np.float64),
                       target_opset=TARGET_OPSET)
-
+        self.assertIn('output: "variable"', str(onx))
         sess = InferenceSession(onx.SerializeToString())
-
         exp = dec.transform(X.astype(np.float64))
         got = sess.run(None, {'X': X.astype(np.float64)})[0]
         assert_almost_equal(got, exp, decimal=4)
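
Usage sketch (not part of the diff): the pattern the updated tests exercise is a custom transformer whose converter delegates to a fitted sub-estimator through the renamed OnnxSubEstimator. The DecorrelateTransformer class, the shape calculator, the alias "CustomDecorrelateTransformer" and the random data below are illustrative assumptions; only OnnxSubEstimator, OnnxIdentity, update_registered_converter and to_onnx come from the code this change touches.

import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.decomposition import PCA
from skl2onnx import to_onnx, update_registered_converter
from skl2onnx.algebra.onnx_ops import OnnxIdentity
from skl2onnx.algebra.onnx_operator import OnnxSubEstimator


class DecorrelateTransformer(TransformerMixin, BaseEstimator):
    """Hypothetical transformer keeping a fitted PCA in `pca_`."""

    def fit(self, X, y=None):
        self.pca_ = PCA(n_components=X.shape[1]).fit(X)
        return self

    def transform(self, X):
        return self.pca_.transform(X)


def decorrelate_shape_calculator(operator):
    # The output keeps the input tensor type and shape.
    input_type = operator.inputs[0].type
    operator.outputs[0].type = input_type.__class__(input_type.shape)


def decorrelate_converter(scope, operator, container):
    op = operator.raw_operator
    opv = container.target_opset
    X = operator.inputs[0]
    # The fitted PCA is inserted as a sub-estimator; an Identity node
    # connects its output to the declared output of this operator,
    # which is what keeps `is_fed` consistent in Topology.
    sub = OnnxSubEstimator(op.pca_, X, op_version=opv)
    identity = OnnxIdentity(sub, op_version=opv,
                            output_names=operator.outputs[:1])
    identity.add_to(scope, container)


update_registered_converter(
    DecorrelateTransformer, "CustomDecorrelateTransformer",
    decorrelate_shape_calculator, decorrelate_converter)

X = np.random.randn(20, 4).astype(np.float32)
onx = to_onnx(DecorrelateTransformer().fit(X), X)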