Track used output tensors explicitly.
Kludge - see comments in this commit for explanation.
kraiskil committed May 18, 2023
1 parent 74abde7 commit 57b8241
Showing 3 changed files with 36 additions and 18 deletions.
23 changes: 22 additions & 1 deletion src/graph.cc
@@ -270,7 +270,6 @@ bool Graph::tryResolveNode(onnx::NodeProto &node)
const_cast<Tensor*>(i)->consumers.push_back(n);
}

n->onnx_node = &node;
n->isResolved = false;
n->op_name = new_node;
n->onnx_name = node.name();
@@ -289,6 +288,28 @@ bool Graph::tryResolveNode(onnx::NodeProto &node)
if( node.attribute_size() != 0 )
n->parseAttributes( node );

// Record which of this node's outputs are actually used.
// This is a kludge around a chicken & egg problem caused by bad design in
// onnx2c:
// we don't want to keep a copy of (or pointer to) the onnx::NodeProto in the onnx2c::Node object
// (since who knows how protobuf keeps its internals),
// so build a list that tells whether each output is used or not *before* resolving
// the node.
std::vector<bool> output_used;
for(int nn = 0; nn<node.output_size(); nn++)
{
// ONNX spec:
// "There are two ways to leave an optional input or output unspecified:
// the first, available only for trailing inputs and outputs, is to simply
// not provide that input; the second method is to use an empty string in
// place of an input or output name."
if( node.output(nn) == "" )
output_used.push_back(false);
else
output_used.push_back(true);
}
n->set_output_used(output_used);

// Configure Node internals, and populate its outputs vector.
n->resolve();

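For readers unfamiliar with the convention the new loop relies on, here is a minimal, self-contained sketch (hypothetical code, not from the repository) of the mapping it performs: a plain vector of strings stands in for the protobuf accessor node.output() so the example compiles without protobuf, and a named output slot becomes true while an empty name becomes false, per the ONNX spec quoted above.

// Hypothetical, standalone illustration of the mapping done in
// Graph::tryResolveNode(). The names and values are made up.
#include <cassert>
#include <string>
#include <vector>

static std::vector<bool> mark_used_outputs(const std::vector<std::string> &output_names)
{
	std::vector<bool> used;
	for( const auto &name : output_names )
		used.push_back( !name.empty() );  // "" means the optional output is not used
	return used;
}

int main(void)
{
	// e.g. a node where the second output slot is skipped with an empty string
	std::vector<bool> used = mark_used_outputs( {"Y", "", "Y_c"} );
	assert( used[0] == true  );
	assert( used[1] == false );
	assert( used[2] == true  );
	return 0;
}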
15 changes: 3 additions & 12 deletions src/node.cc
@@ -7,20 +7,11 @@
using namespace toC;

int64_t Node::onnx_ir_version;
bool Node::is_output_N_used(unsigned N)
bool Node::is_output_N_used(unsigned N) const
{
// ONNX spec:
// "There are two ways to leave an optional input or output unspecified:
// the first, available only for trailing inputs and outputs, is to simply
// not provide that input; the second method is to use an empty string in
// place of an input or output name."

if( (int)N >= onnx_node->output_size() )
return false;

if( onnx_node->output(N) == "" )
if( N >= output_used.size() )
return false;

return output_used[N];
return true;
}

16 changes: 11 additions & 5 deletions src/node.h
@@ -19,7 +19,6 @@ typedef std::tuple<const Tensor *, std::string> function_parameter;
class Node {
public:
bool isResolved; // has this node been visited in current compilation step.
const onnx::NodeProto *onnx_node;
std::string onnx_name; //ONNX name of the individual node
std::string op_name; //ONNX name of node type
static int64_t onnx_ir_version;
@@ -28,12 +27,19 @@ class Node {
// NB: this is deprecated. Whenever a node is updated,
// any reference to this variable should be removed.
// instead of outputs.push_back(), use register_output()
// Eventually this variable should be made protected
std::vector<Tensor *> outputs;
private:
std::vector<function_parameter> input_params;
std::vector<function_parameter> output_params;
// truth table telling if the Nth output is used or not.
// This might not be as long as the number of outputs in the node's operator specification
// (i.e. when trailing outputs are not used)
std::vector<bool> output_used;

public:
void set_output_used(std::vector<bool>val){output_used = val; }

// when output is removed, get the vector of tensors from output_params.
std::vector<Tensor *> get_outputs(void) const {return outputs;}

@@ -47,7 +53,7 @@ class Node {


/* Print the C implementation of the operator */
virtual void print(std::ostream &destination) const = 0;
virtual void print(std::ostream &destination) const = 0;

/* Print comma-separated list of function parameters.
* Unused optional tensors skipped. e.g.:
@@ -77,8 +83,8 @@ class Node {

/* Check if an optional output is used in the network.
* N is the Nth output specified in the Operators.md specification for this node.
* Start counting N from 0. */
bool is_output_N_used(unsigned N);
* Start counting N from 0, including the non-optional outputs. */
bool is_output_N_used(unsigned N) const;

/* Not all node types have attributes. Override where needed */
virtual void parseAttributes( onnx::NodeProto &node )
@@ -87,7 +93,7 @@ class Node {
}

/* TODO: these should be part of class Tensor... */
/* Check input constraints, as used in
/* Check input constraints, as used in
* https://github.com/onnx/onnx/blob/master/docs/Operators.md
*/
/* (u)int32, (u)int64, float16/32/64, bfloat*/
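As a usage sketch of the new interface (hypothetical operator, not part of this commit): an operator implementation with an optional trailing output would call is_output_N_used() from its print() so that no C code is emitted for an output the model never names. The class and member names below are illustrative only; the lookup mirrors the updated node.cc logic, including the bounds check that makes omitted trailing outputs report as unused.

// Hypothetical, self-contained sketch of the intended call site.
// FakeDropout stands in for an onnx2c operator with an optional second
// output (a mask); it is not the project's real Dropout class.
#include <iostream>
#include <vector>

struct FakeDropout {
	std::vector<bool> output_used;   // filled by the graph parser, as in graph.cc above

	bool is_output_N_used(unsigned N) const {
		// same contract as the updated Node::is_output_N_used():
		// an index past the end means a trailing output that was never listed
		return N < output_used.size() && output_used[N];
	}

	void print(std::ostream &dst) const {
		dst << "\t/* code writing output 0 (the data) goes here */\n";
		// Output 1 (the mask) is optional: emit code for it only when
		// the model actually names that output slot.
		if( is_output_N_used(1) )
			dst << "\t/* code writing output 1 (the mask) goes here */\n";
	}
};

int main(void)
{
	FakeDropout node;
	node.output_used = { true };   // the mask output was left unspecified
	node.print(std::cout);         // prints only the output-0 line
	return 0;
}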
