From 72ea78cd30610c7303533e9fc02d51f9764f3726 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Wed, 9 Feb 2022 15:19:41 +0100 Subject: [PATCH] Fix a validation verifying dimension type MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: xavier dupré --- skl2onnx/common/_topology.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/skl2onnx/common/_topology.py b/skl2onnx/common/_topology.py index 7419a56e6..caf91db2d 100644 --- a/skl2onnx/common/_topology.py +++ b/skl2onnx/common/_topology.py @@ -150,7 +150,7 @@ def _check(self): for k in self.type.shape: if k is None: continue - if not isinstance(k, (int, np.int64, np.intc)): + if not isinstance(k, (int, np.integer)): raise ValueError( "Unexpected type %r for shape %r." "" % (type(k), self)) @@ -1090,11 +1090,14 @@ def _initialize_graph_status_for_traversing(self): for operator in self.unordered_operator_iterator(): operator.init_status(is_evaluated=False) - def _propagate_status(self, operator, container, fed_variables): + def _propagate_status(self, operator, container, fed_variables, + verbose=0): """ Propagates status *is_fed* based on output variable and node added in the container. """ + if verbose > 1: + print("[_propagate_status] after op=%r" % operator) vars = {} for node in container.nodes: for i in node.input: @@ -1102,8 +1105,10 @@ def _propagate_status(self, operator, container, fed_variables): vars[i] = [] vars[i].append(node) - stack = [v.onnx_name for v in operator.outputs if v.is_fed] - stack.extend(v.onnx_name for v in operator.inputs if v.is_fed) + if verbose > 1: + print("[_propagate_status] newly fed=%r" % list( + v.onnx_name for v in operator.outputs if v.is_fed)) + stack = list(fed_variables) scope = self.scopes[0] while len(stack) > 0: nodes = {} @@ -1117,11 +1122,15 @@ def _propagate_status(self, operator, container, fed_variables): if all(fed_variables.get(n, False) for n in node.input): for o in node.output: if o not in fed_variables: + if verbose > 1: + print("[_propagate_status] add=%r" % o) fed_variables[o] = o stack.append(o) if o in scope.variables: var = scope.variables[o] var.init_status(is_fed=True) + if verbose > 1: + print("[_propagate_status] fed=%r" % var) def convert_operators(self, container=None, verbose=0): """ @@ -1239,7 +1248,9 @@ def _check_variable_out_(variable, operator): fed_variables.update( {i.name: i for i in container.initializers if i.name not in fed_variables}) - self._propagate_status(operator, container, fed_variables) + self._propagate_status(operator, container, fed_variables, + verbose=verbose) + # unfed some variables (it happens when a node # shares an output with another node) rem = []