From 57b8241038ad99c2e7045ec832384821db5d7554 Mon Sep 17 00:00:00 2001 From: Kalle Raiskila Date: Sat, 13 May 2023 15:40:45 +0200 Subject: [PATCH] Track used output tensors explicitly. Kludge - see comments in this commit for explanation. --- src/graph.cc | 23 ++++++++++++++++++++++- src/node.cc | 15 +++------------ src/node.h | 16 +++++++++++----- 3 files changed, 36 insertions(+), 18 deletions(-) diff --git a/src/graph.cc b/src/graph.cc index d7e110a..fad7e1a 100644 --- a/src/graph.cc +++ b/src/graph.cc @@ -270,7 +270,6 @@ bool Graph::tryResolveNode(onnx::NodeProto &node) const_cast(i)->consumers.push_back(n); } - n->onnx_node = &node; n->isResolved = false; n->op_name = new_node; n->onnx_name = node.name(); @@ -289,6 +288,28 @@ bool Graph::tryResolveNode(onnx::NodeProto &node) if( node.attribute_size() != 0 ) n->parseAttributes( node ); + // create output nodes for the tensor. + // this is a kludge around a chicken & egg problem caused by bad design in + // onnx2c: + // we don't want to save a copy of the onnx::NodeProto in the onnx2c::node object + // (since who knows how protobuf keeps its internals). + // So create a list of that tells if outputs are used or not *before* resolving + // the node. + std::vector output_used; + for(int nn = 0; nnset_output_used(output_used); + // Configure Node internals, and populate its outputs vector. n->resolve(); diff --git a/src/node.cc b/src/node.cc index 8108c3b..23b5b49 100644 --- a/src/node.cc +++ b/src/node.cc @@ -7,20 +7,11 @@ using namespace toC; int64_t Node::onnx_ir_version; -bool Node::is_output_N_used(unsigned N) +bool Node::is_output_N_used(unsigned N) const { - // ONNX spec: - // "There are two ways to leave an optional input or output unspecified: - // the first, available only for trailing inputs and outputs, is to simply - // not provide that input; the second method is to use an empty string in - // place of an input or output name." - - if( (int)N >= onnx_node->output_size() ) - return false; - - if( onnx_node->output(N) == "" ) + if( N >= output_used.size() ) return false; - + return output_used[N]; return true; } diff --git a/src/node.h b/src/node.h index 58c2b87..cdcf656 100644 --- a/src/node.h +++ b/src/node.h @@ -19,7 +19,6 @@ typedef std::tuple function_parameter; class Node { public: bool isResolved; // has this node been visited in current compilation step. - const onnx::NodeProto *onnx_node; std::string onnx_name; //ONNX name of the individual node std::string op_name; //ONNX name of node type static int64_t onnx_ir_version; @@ -28,12 +27,19 @@ class Node { // NB: this is deprecated. Whenever a node is updated, // any reference to this variable should be removed. // instead of outputs.push_back(), use register_output() + // Eventually this variable should be made protected std::vector outputs; private: std::vector input_params; std::vector output_params; + // truth table telling if the Nth output is used or not. + // This might not be as long as the number of outputs in the Node operand's specification + // (i.e .when trailing outputs are not used) + std::vector output_used; public: + void set_output_used(std::vectorval){output_used = val; } + // when output is removed, get the vector of tensors from output_params. std::vector get_outputs(void) const {return outputs;} @@ -47,7 +53,7 @@ class Node { /* Print the C implmementation of the operator */ - virtual void print(std::ostream &destination) const = 0; + virtual void print(std::ostream &destination) const = 0; /* Print comma-separated list of function parameters. * Unused optional tensors skipped. e.g.: @@ -77,8 +83,8 @@ class Node { /* Check if an optional output is used in the network. * N is Nth output specified in the Operator.md specification for this node. - * Start counting N from 0. */ - bool is_output_N_used(unsigned N); + * Start counting N from 0, including the non-optional outputs. */ + bool is_output_N_used(unsigned N) const; /* Not all node types have attributes. Override where needed */ virtual void parseAttributes( onnx::NodeProto &node ) @@ -87,7 +93,7 @@ class Node { } /* TODO: these should be part of class Tensor... */ - /* Check input constraints, as used in + /* Check input constraints, as used in * https://github.com/onnx/onnx/blob/master/docs/Operators.md */ /* (u)int32, (u)int64, float16/32/64, bfloat*/