Remove Node local copies of input and output
Having copies of the input and output tensors in a node
causes problems when trying to manipulate the graph.

Until now there were no graph-manipulating optimizations,
so this flawed design could persist. But now
it blocks the Cast-folding optimization pass.

Fix nodes up to (alphabetically) convinteger to comply.
kraiskil committed Apr 22, 2023
1 parent 34f8af0 commit 1b33912
Showing 13 changed files with 100 additions and 228 deletions.
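The hazard described in the commit message can be sketched in a few lines. This is illustrative code, not part of the commit; only the Node/Tensor names and the inputs/outputs vectors mirror what the diffs below show:

#include <vector>

struct Tensor { /* shape, type, data ... */ };

struct Node {
	std::vector<Tensor*> inputs;    // owned by the graph, shared between nodes
	std::vector<Tensor*> outputs;
	virtual void resolve() = 0;
};

// Old style: the node caches a private copy of the pointer at resolve() time.
struct OldStyleNode : Node {
	const Tensor *input_1 = nullptr;
	void resolve() override { input_1 = inputs[0]; }
	// print() etc. later read input_1, not inputs[0]
};

// A graph-manipulating pass (e.g. Cast folding) rewires a consumer node:
void rewire_input(Node &consumer, Tensor *replacement) {
	consumer.inputs[0] = replacement;
	// An OldStyleNode still holds the stale input_1 pointer, so the C code
	// it later emits refers to a tensor the optimizer removed. A node that
	// always reads inputs[0] directly stays consistent automatically.
}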
15 changes: 2 additions & 13 deletions src/nodes/TEMPLATE
@@ -22,19 +22,11 @@ class TEMPLATE : public Node {
public:
TEMPLATE() {
op_name = "TEMPLATE";
input_1=input_2_optional=output_1=output_2_optional=NULL;
}
/* Examples of ONNX Operand attributes */
std::vector<float> a_floatarray_attribute;
int an_int_attribute;

// input and output tensors
const Tensor *input_1;
const Tensor *input_2_optional;
const Tensor *output_1;
const Tensor *output_2_optional;


// Mandatory "API" functions towards the rest of onnx2c
virtual void parseAttributes( onnx::NodeProto &node ) override;
virtual void resolve(void) override;
@@ -60,14 +52,14 @@ void TEMPLATE::parseAttributes( onnx::NodeProto &node )
/* Assign input tensors, resolve output tensor shapes, allocate output tensors */
void TEMPLATE::resolve(void)
{
input_1 = inputs[0];
Tensor *input_1 = inputs[0];
// Remember the parameters to the generated function,
// along with a descriptive name that is used locally in the generated source.
// The most "descriptive name" usually is the one this tensor has in the ONNX documentation.
register_input(input_1, "A");

if (inputs.size() == 2) {
input_2_optional = inputs[1];
Tensor *input_2_optional = inputs[1];
register_input(input_2_optional, "descriptive_name");
}
// else leave input_2_optional as null so other functions here know to ignore it
@@ -78,9 +70,6 @@ void TEMPLATE::resolve(void)
Tensor *t = new Tensor;
t->data_dim.push_back(42);
t->data_type = onnx::TensorProto_DataType_FLOAT;
/* Store the created tensor both as reference in this node, and into
* the return value vector! */
output_1 = t;
register_output(t, "Y");

/* TODO: optional outputs? */
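As the comments in resolve() above say, registered tensors become parameters of the C function onnx2c emits for the node, under the registered "descriptive name". Purely for illustration (the signature, naming scheme and shapes are assumed here, not taken from real onnx2c output), the registered names end up used like this:

/* Hypothetical output for a node that registered input "A" and an
 * output "Y" resolved to shape {42}; only the idea that the registered
 * names become the parameter names is taken from the comments above. */
static void node_TEMPLATE_example(const float A[42], float Y[42])
{
	for (int i = 0; i < 42; i++)
		Y[i] = A[i]; /* operator-specific computation goes here */
}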
5 changes: 0 additions & 5 deletions src/nodes/averagepool.h
@@ -13,12 +13,8 @@ class AveragePool : public Pooling {
public:
AveragePool() : Pooling() {
op_name = "AveragePool";
Indices = NULL;
}

// optional outputs
const Tensor *Indices;

virtual void print_output_cell_init(std::ostream &dst, const std::string &y_idx) const override
{
INDT_3 << y->data_type_str() << " curavg = 0.0;" << std::endl;
@@ -80,7 +76,6 @@ class AveragePool : public Pooling {
Tensor *indices_out = new Tensor;
indices_out->data_type = onnx::TensorProto_DataType::TensorProto_DataType_INT64;
indices_out->data_dim = rv->data_dim;
Indices = indices_out;
register_output(indices_out, "ind");
}
};
71 changes: 20 additions & 51 deletions src/nodes/batchnormalization.h
@@ -24,24 +24,13 @@ class BatchNormalization : public Node {
op_name = "BatchNormalization";
epsilon = 1e-5;
momentum = 0.9;
input=scale=bias=mean=var=output=NULL;
sqrt_var_offline = false;
}
bool sqrt_var_offline; // TODO: is it ever possible that we can't compute sqrt(var) offline?

float epsilon;
float momentum;

// inputs
const Tensor *input; // 'X' in the spec
const Tensor *scale;
const Tensor *bias; // 'B' in the spec
const Tensor *mean;
const Tensor *var;
// outputs
const Tensor *output;
// ... optional outputs not yet implmeneted


void parseAttribute_epsilon( const onnx::AttributeProto &a ) {
if( a.type() != onnx::AttributeProto_AttributeType_FLOAT )
@@ -78,6 +67,9 @@ class BatchNormalization : public Node {

virtual void print(std::ostream &dst) const override
{
Tensor *input = inputs[0];
const Tensor *scale = inputs[1];
const Tensor *bias = inputs[2];
int batch_size =input->data_dim[0];
int num_chan =input->data_dim[1];
std::string type = input->data_type_str();
@@ -114,10 +106,11 @@ class BatchNormalization : public Node {
dst << "( sqrt( var[c] + epsilon));" << std::endl;

INDT_2 << "output" << idxs << " = tmp_X ";
if( scale )
dst << "* scale[c]";
if( bias )
dst << " + bias[c]";

if( !isSplatted(scale, 1.0f) )
dst << "* scale[c]";
if( !isSplatted(bias, 0.0f) )
dst << " + bias[c]";
dst << ";" << std::endl;

for( unsigned i = 2; i<input->data_dim.size(); i++)
@@ -128,7 +121,7 @@ }
}

// TODO: this could be useful elsewhere too
bool isSplatted(const Tensor *t, float value)
bool isSplatted(const Tensor *t, float value) const
{
if( t->data_type != onnx::TensorProto_DataType_FLOAT )
ERROR("Unimplemented");
@@ -147,7 +140,7 @@ class BatchNormalization : public Node {
// Updates variance tensor in-place to contain the entire denominator
// of the BatchNormalization formula.
// TODO: This breaks if var is used anywere else.
void calculateSqrtVarOffline(void)
void calculateSqrtVarOffline(Tensor *var)
{
float *v = (float*)var->data_buffer;
for( int i=0; i<var->data_num_elem(); i++)
@@ -159,44 +152,20 @@
if( inputs.size() != 5 )
ERROR("wrong number of inputs to BatchNormalization");

input = inputs[0]; // "X"
register_input(input, "X");
scale = inputs[1];
register_input(scale, "scale");
bias = inputs[2]; // "B" in spec
register_input(bias, "bias");
mean = inputs[3];
register_input(mean, "mean");
var = inputs[4];
register_input(var, "var");

if( typeConstraint_plainFloatingPoints(input) == false)
ERROR("Incorrect input for node");
if( typeConstraint_plainFloatingPoints(scale) == false)
ERROR("Incorrect input for node");
if( typeConstraint_plainFloatingPoints(bias) == false)
ERROR("Incorrect input for node");
if( typeConstraint_plainFloatingPoints(mean) == false)
ERROR("Incorrect input for node");
if( typeConstraint_plainFloatingPoints(var) == false)
ERROR("Incorrect input for node");

// It is possible that scale is all ones, and bias is all zeros!
// But the scale and bias tensors are not optional inputs in ONNX, so they are always
// provided.
if( isSplatted(scale, 1.0f) )
scale = NULL;
if( isSplatted(bias, 0.0f) )
bias = NULL;
if( var->isConst ) {
calculateSqrtVarOffline();
register_input(inputs[0], "X");
register_input(inputs[1], "scale");
register_input(inputs[2], "bias");
register_input(inputs[3], "mean");
register_input(inputs[4], "var");

if( inputs[4]->isConst ) {
calculateSqrtVarOffline(inputs[4]);
sqrt_var_offline = true;
}

Tensor *rv = new Tensor;
rv->data_dim = input->data_dim;
rv->data_type = input->data_type;
output = rv;
rv->data_dim = inputs[0]->data_dim;
rv->data_type = inputs[0]->data_type;
register_output(rv, "output");
}
};
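For reference, BatchNormalization computes y = (x - mean) / sqrt(var + epsilon) * scale + bias per channel. The calculateSqrtVarOffline() call above takes advantage of a constant var tensor by folding sqrt(var + epsilon) into it once, so the generated code only divides by the stored value. A sketch of what that amounts to, reusing the Tensor accessors visible in the diff (the elided loop body in the real function may differ):

#include <cmath>

// Illustrative only: fold the BatchNormalization denominator into a
// constant variance tensor in place, ahead of code generation.
static void fold_denominator(Tensor *var, float epsilon)
{
	float *v = (float*)var->data_buffer;
	for (int i = 0; i < var->data_num_elem(); i++)
		v[i] = std::sqrt(v[i] + epsilon);  // emitted code then divides by v[c]
}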
12 changes: 6 additions & 6 deletions src/nodes/cast.cc
@@ -20,10 +20,8 @@ void Cast::parseAttributes( onnx::NodeProto &node )

void Cast::resolve(void)
{
// TODO: should we warn user here. What is the use-case of 'Cast' in embedded systems?

input = inputs[0];
register_input(input, "input");
LOG(INFO) << "'Cast' node found." << std::endl;
register_input(inputs[0], "input");

switch(to)
{
@@ -37,16 +35,18 @@ }
}

Tensor *t = new Tensor;
t->data_dim = input->data_dim;
t->data_dim = inputs[0]->data_dim;
t->data_type = static_cast<onnx::TensorProto_DataType>(to);
output = t;
register_output(t, "output");
}


void Cast::print(std::ostream &dst) const
{
INDT_1 << "/* Cast */" << std::endl;
const Tensor *input = inputs[0];
const Tensor *output = outputs[0];

std::string intype = input->data_type_str();
std::string outtype = output->data_type_str();

5 changes: 0 additions & 5 deletions src/nodes/cast.h
@@ -10,17 +10,12 @@ class Cast : public Node {
public:
Cast() {
op_name = "Cast";
input=output=NULL;
to=-1;
}
int to;

std::string output_type;

// input and output tensors
const Tensor *input;
const Tensor *output;

virtual void parseAttributes( onnx::NodeProto &node ) override;
virtual void resolve(void) override;
virtual void print(std::ostream &dst) const override;
56 changes: 17 additions & 39 deletions src/nodes/clip.h
@@ -9,7 +9,6 @@ class Clip : public Node {
public:
Clip() {
op_name = "Clip";
input=min_tensor=max_tensor=output=NULL;
min_attr = std::numeric_limits<float>::lowest();
max_attr = std::numeric_limits<float>::max();
}
Expand All @@ -19,11 +18,6 @@ class Clip : public Node {
// works for all versions of Clip
float min_attr, max_attr;

const Tensor *input;
const Tensor *min_tensor;
const Tensor *max_tensor;
const Tensor *output;


virtual void parseAttributes( onnx::NodeProto &node ) override {
for( const auto& a : node.attribute() ) {
@@ -40,55 +34,40 @@

virtual void resolve(void) override
{
input = inputs[0];

const Tensor *input = inputs[0];
register_input(inputs[0], "input");
if (inputs.size() > 1 && inputs[1]->is_used())
min_tensor = inputs[1];
register_input(inputs[1], "min_tensor");
if (inputs.size() > 2 && inputs[2]->is_used())
max_tensor = inputs[2];

register_input(inputs[2], "max_tensor");

Tensor *t = new Tensor;
t->data_dim = input->data_dim;
t->data_type = input->data_type;
/* Store the created tensor both as reference in this node, and into
* the return value vector! */
output = t;
outputs.push_back(t);
register_output(t, "output");
}


virtual void print_parameters(std::ostream &dst, bool decorate ) const override
virtual void print(std::ostream &dst) const override
{
input->print_tensor_as_const(dst, !decorate);

if (min_tensor) {
dst << ", ";
min_tensor->print_tensor_as_const(dst, !decorate);
}

if (max_tensor) {
dst << ", ";
max_tensor->print_tensor_as_const(dst, !decorate);
}

dst << ", ";
output->print_tensor(dst, !decorate);
}

const Tensor *input = inputs[0];
const Tensor *min_tensor = nullptr;
const Tensor *max_tensor = nullptr;

virtual void print(std::ostream &dst) const override
{
if (inputs.size() > 1 && inputs[1]->is_used())
min_tensor = inputs[1];
if (inputs.size() > 2 && inputs[2]->is_used())
max_tensor = inputs[2];

INDT_1 << "/* Clip */" << std::endl;

if( min_tensor )
INDT_1 << min_tensor->data_type_str() << " minv = " << min_tensor->cname() << "[0];" << std::endl;
INDT_1 << min_tensor->data_type_str() << " minv = min_tensor[0];" << std::endl;
else
INDT_1 << "float minv = " << min_attr << ";" << std::endl;

if( max_tensor )
INDT_1 << max_tensor->data_type_str() << " maxv = " << max_tensor->cname() << "[0];" << std::endl;
INDT_1 << max_tensor->data_type_str() << " maxv = max_tensor[0];" << std::endl;
else
INDT_1 << "float maxv = " << max_attr << ";" << std::endl;

@@ -102,13 +81,12 @@ class Clip : public Node {
idx += "[" + lv + "]";
}

INDT_2 << output->cname() << idx << " = ";
dst << "MAX( MIN( " << input->cname() << idx << ", maxv), minv);" << std::endl;
INDT_2 << "output" << idx << " = ";
dst << "MAX( MIN( input"<< idx << ", maxv), minv);" << std::endl;

for( unsigned r=0; r<input->rank(); r++) {
INDT_1 << "}" << std::endl;
}

}
};
}
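To make the print() template above concrete: for a hypothetical 1x4 float input with a min tensor but no max tensor, the emitted C would look roughly like the following. Loop-variable names, the array shapes and the MAX/MIN helpers are assumptions, not output from a real onnx2c run:

/* Clip */
float minv = min_tensor[0];
float maxv = 3.40282347e+38;   /* max_attr default, std::numeric_limits<float>::max() */
for (unsigned i0 = 0; i0 < 1; i0++) {
for (unsigned i1 = 0; i1 < 4; i1++) {
	output[i0][i1] = MAX( MIN( input[i0][i1], maxv), minv);
}
}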