diff --git a/src/nodes/TEMPLATE b/src/nodes/TEMPLATE
index b44684c..94109a6 100644
--- a/src/nodes/TEMPLATE
+++ b/src/nodes/TEMPLATE
@@ -22,19 +22,11 @@ class TEMPLATE : public Node {
     public:
     TEMPLATE() {
         op_name = "TEMPLATE";
-        input_1=input_2_optional=output_1=output_2_optional=NULL;
     }

     /* Examples of ONNX Operand attributes */
     std::vector<float> a_floatarray_attribute;
     int an_int_attribute;

-    // input and output tensors
-    const Tensor *input_1;
-    const Tensor *input_2_optional;
-    const Tensor *output_1;
-    const Tensor *output_2_optional;
-
     // Mandatory "API" functions towards the rest of onnx2c
     virtual void parseAttributes( onnx::NodeProto &node ) override;
     virtual void resolve(void) override;
@@ -60,14 +52,14 @@ void TEMPLATE::parseAttributes( onnx::NodeProto &node )
 /* Assign input tensors, resolve output tensor shapes, allocate output tensors */
 void TEMPLATE::resolve(void)
 {
-    input_1 = inputs[0];
+    Tensor *input_1 = inputs[0];
     // Remember the parameters to the generated function,
     // along with a descriptive name that is used locally in the generated source.
     // The most "descriptive name" usually is the one this tensor has in the ONNX documentation.
     register_input(input_1, "A");

     if (inputs.size() == 2) {
-        input_2_optional = inputs[1];
+        Tensor *input_2_optional = inputs[1];
         register_input(input_2_optional, "descriptive_name");
     }
     // else leave input_2_optional as null so other functions here know to ignore it
@@ -78,9 +70,6 @@ void TEMPLATE::resolve(void)
     Tensor *t = new Tensor;
     t->data_dim.push_back(42);
     t->data_type = onnx::TensorProto_DataType_FLOAT;
-    /* Store the created tensor both as reference in this node, and into
-     * the return value vector! */
-    output_1 = t;
     register_output(t, "Y");

     /* TODO: optional outputs? */
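[annotation, not part of the patch] The strings passed to register_input() / register_output() become the parameter names of the C function that onnx2c generates for the node, which is why the print() bodies in the hunks below can emit plain identifiers such as "output" or "bias" instead of going through Tensor::cname(). A minimal sketch of the idea; the generated function name, element type and the array size 42 are hypothetical, chosen only to match the TEMPLATE above:

    /* Hypothetical generated code: "A" and "Y" are the names registered
     * in TEMPLATE::resolve(), so the node body can index them directly. */
    static void TEMPLATE_node( const float A[42], float Y[42] )
    {
        for( unsigned i = 0; i < 42; i++ )
            Y[i] = A[i]; /* stand-in for the node's real computation */
    }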
diff --git a/src/nodes/averagepool.h b/src/nodes/averagepool.h
index eeb4849..1818cc8 100644
--- a/src/nodes/averagepool.h
+++ b/src/nodes/averagepool.h
@@ -13,12 +13,8 @@ class AveragePool : public Pooling {
     public:
     AveragePool() : Pooling() {
         op_name = "AveragePool";
-        Indices = NULL;
     }

-    // optional outputs
-    const Tensor *Indices;
-
     virtual void print_output_cell_init(std::ostream &dst, const std::string &y_idx) const override
     {
         INDT_3 << y->data_type_str() << " curavg = 0.0;" << std::endl;
@@ -80,7 +76,6 @@ class AveragePool : public Pooling {
         Tensor *indices_out = new Tensor;
         indices_out->data_type = onnx::TensorProto_DataType::TensorProto_DataType_INT64;
         indices_out->data_dim = rv->data_dim;
-        Indices = indices_out;
         register_output(indices_out, "ind");
     }
 };
diff --git a/src/nodes/batchnormalization.h b/src/nodes/batchnormalization.h
index 85cba65..98b10c4 100644
--- a/src/nodes/batchnormalization.h
+++ b/src/nodes/batchnormalization.h
@@ -24,7 +24,6 @@ class BatchNormalization : public Node {
         op_name = "BatchNormalization";
         epsilon = 1e-5;
         momentum = 0.9;
-        input=scale=bias=mean=var=output=NULL;
         sqrt_var_offline = false;
     }
     bool sqrt_var_offline; // TODO: is it ever possible that we can't compute sqrt(var) offline?
@@ -32,16 +31,6 @@ class BatchNormalization : public Node {
     float epsilon;
     float momentum;

-    // inputs
-    const Tensor *input; // 'X' in the spec
-    const Tensor *scale;
-    const Tensor *bias; // 'B' in the spec
-    const Tensor *mean;
-    const Tensor *var;
-    // outputs
-    const Tensor *output;
-    // ... optional outputs not yet implemented
-
     void parseAttribute_epsilon( const onnx::AttributeProto &a ) {
         if( a.type() != onnx::AttributeProto_AttributeType_FLOAT )
@@ -78,6 +67,9 @@ class BatchNormalization : public Node {

     virtual void print(std::ostream &dst) const override
     {
+        Tensor *input = inputs[0];
+        const Tensor *scale = inputs[1];
+        const Tensor *bias = inputs[2];
         int batch_size =input->data_dim[0];
         int num_chan =input->data_dim[1];
         std::string type = input->data_type_str();
@@ -114,10 +106,11 @@ class BatchNormalization : public Node {
             dst << "( sqrt( var[c] + epsilon));" << std::endl;

         INDT_2 << "output" << idxs << " = tmp_X ";
-        if( scale )
-            dst << "* scale[c]";
-        if( bias )
-            dst << " + bias[c]";
+
+        if( !isSplatted(scale, 1.0f) )
+            dst << "* scale[c]";
+        if( !isSplatted(bias, 0.0f) )
+            dst << " + bias[c]";
         dst << ";" << std::endl;

         for( unsigned i = 2; i<input->data_dim.size(); i++)
@@ -128,7 +121,7 @@ class BatchNormalization : public Node {
     }

     // TODO: this could be useful elsewhere too
-    bool isSplatted(const Tensor *t, float value)
+    bool isSplatted(const Tensor *t, float value) const
     {
         if( t->data_type != onnx::TensorProto_DataType_FLOAT )
             ERROR("Unimplemented");
@@ -147,7 +140,7 @@ class BatchNormalization : public Node {
     // Updates variance tensor in-place to contain the entire denominator
     // of the BatchNormalization formula.
    // TODO: This breaks if var is used anywhere else.
-    void calculateSqrtVarOffline(void)
+    void calculateSqrtVarOffline(Tensor *var)
     {
         float *v = (float*)var->data_buffer;
         for( int i=0; i<var->data_num_elem(); i++)
@@ -159,44 +152,20 @@ class BatchNormalization : public Node {
         if( inputs.size() != 5 )
             ERROR("wrong number of inputs to BatchNormalization");

-        input = inputs[0]; // "X"
-        register_input(input, "X");
-        scale = inputs[1];
-        register_input(scale, "scale");
-        bias = inputs[2]; // "B" in spec
-        register_input(bias, "bias");
-        mean = inputs[3];
-        register_input(mean, "mean");
-        var = inputs[4];
-        register_input(var, "var");
-
-        if( typeConstraint_plainFloatingPoints(input) == false)
-            ERROR("Incorrect input for node");
-        if( typeConstraint_plainFloatingPoints(scale) == false)
-            ERROR("Incorrect input for node");
-        if( typeConstraint_plainFloatingPoints(bias) == false)
-            ERROR("Incorrect input for node");
-        if( typeConstraint_plainFloatingPoints(mean) == false)
-            ERROR("Incorrect input for node");
-        if( typeConstraint_plainFloatingPoints(var) == false)
-            ERROR("Incorrect input for node");
-
-        // It is possible that scale is all ones, and bias is all zeros!
-        // But the scale and bias tensors are not optional inputs in ONNX, so they are always
-        // provided.
-        if( isSplatted(scale, 1.0f) )
-            scale = NULL;
-        if( isSplatted(bias, 0.0f) )
-            bias = NULL;
-        if( var->isConst ) {
-            calculateSqrtVarOffline();
+        register_input(inputs[0], "X");
+        register_input(inputs[1], "scale");
+        register_input(inputs[2], "bias");
+        register_input(inputs[3], "mean");
+        register_input(inputs[4], "var");
+
+        if( inputs[4]->isConst ) {
+            calculateSqrtVarOffline(inputs[4]);
             sqrt_var_offline = true;
         }

         Tensor *rv = new Tensor;
-        rv->data_dim = input->data_dim;
-        rv->data_type = input->data_type;
-        output = rv;
+        rv->data_dim = inputs[0]->data_dim;
+        rv->data_type = inputs[0]->data_type;
         register_output(rv, "output");
     }
 };
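[annotation] The isSplatted() checks move the "scale is all ones / bias is all zeros" specialization from resolve() into print(), instead of NULLing the member pointers as before. The emitted code still computes output = (X - mean) / sqrt(var + epsilon) * scale + bias, dropping the terms that are splatted away. A sketch of the generated loop for a hypothetical [1][2][4] input with both specializations and the offline sqrt(var) folding applied; illustrative only, not verbatim onnx2c output:

    for( int b=0; b<1; b++ )
    for( int c=0; c<2; c++ )
    for( int i2=0; i2<4; i2++ ) {
        float tmp_X = X[b][c][i2] - mean[c];
        /* var[] was rewritten offline to hold sqrt(var + epsilon) */
        tmp_X = tmp_X / var[c];
        /* "* scale[c]" and "+ bias[c]" omitted: splatted to 1.0 and 0.0 */
        output[b][c][i2] = tmp_X;
    }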
diff --git a/src/nodes/cast.cc b/src/nodes/cast.cc
index b1fbc78..32645b6 100644
--- a/src/nodes/cast.cc
+++ b/src/nodes/cast.cc
@@ -20,10 +20,8 @@ void Cast::parseAttributes( onnx::NodeProto &node )

 void Cast::resolve(void)
 {
-    // TODO: should we warn user here. What is the use-case of 'Cast' in embedded systems?
-
-    input = inputs[0];
-    register_input(input, "input");
+    LOG(INFO) << "'Cast' node found." << std::endl;
+    register_input(inputs[0], "input");

     switch(to)
     {
@@ -37,9 +35,8 @@ void Cast::resolve(void)
     }

     Tensor *t = new Tensor;
-    t->data_dim = input->data_dim;
+    t->data_dim = inputs[0]->data_dim;
     t->data_type = static_cast<onnx::TensorProto_DataType>(to);
-    output = t;
     register_output(t, "output");
 }

@@ -47,6 +44,9 @@ void Cast::resolve(void)
 void Cast::print(std::ostream &dst) const
 {
     INDT_1 << "/* Cast */" << std::endl;
+    const Tensor *input = inputs[0];
+    const Tensor *output = outputs[0];
+
     std::string intype = input->data_type_str();
     std::string outtype = output->data_type_str();
diff --git a/src/nodes/cast.h b/src/nodes/cast.h
index 3c5003a..85ded67 100644
--- a/src/nodes/cast.h
+++ b/src/nodes/cast.h
@@ -10,17 +10,12 @@ class Cast : public Node {
     public:
     Cast() {
         op_name = "Cast";
-        input=output=NULL;
         to=-1;
     }

     int to;
     std::string output_type;

-    // input and output tensors
-    const Tensor *input;
-    const Tensor *output;
-
     virtual void parseAttributes( onnx::NodeProto &node ) override;
     virtual void resolve(void) override;
     virtual void print(std::ostream &dst) const override;
diff --git a/src/nodes/clip.h b/src/nodes/clip.h
index b2a29a3..6e48acd 100644
--- a/src/nodes/clip.h
+++ b/src/nodes/clip.h
@@ -9,7 +9,6 @@ class Clip : public Node {
     public:
     Clip() {
         op_name = "Clip";
-        input=min_tensor=max_tensor=output=NULL;
         min_attr = std::numeric_limits<float>::lowest();
         max_attr = std::numeric_limits<float>::max();
     }
@@ -19,11 +18,6 @@ class Clip : public Node {
     // works for all versions of Clip
     float min_attr, max_attr;

-    const Tensor *input;
-    const Tensor *min_tensor;
-    const Tensor *max_tensor;
-    const Tensor *output;
-
     virtual void parseAttributes( onnx::NodeProto &node ) override
     {
         for( const auto& a : node.attribute() ) {
@@ -40,55 +34,40 @@ class Clip : public Node {

     virtual void resolve(void) override
     {
-        input = inputs[0];
-
+        const Tensor *input = inputs[0];
+        register_input(inputs[0], "input");
         if (inputs.size() > 1 && inputs[1]->is_used())
-            min_tensor = inputs[1];
+            register_input(inputs[1], "min_tensor");
         if (inputs.size() > 2 && inputs[2]->is_used())
-            max_tensor = inputs[2];
-
+            register_input(inputs[2], "max_tensor");
         Tensor *t = new Tensor;
         t->data_dim = input->data_dim;
         t->data_type = input->data_type;
-        /* Store the created tensor both as reference in this node, and into
-         * the return value vector! */
-        output = t;
-        outputs.push_back(t);
+        register_output(t, "output");
     }
-
-    virtual void print_parameters(std::ostream &dst, bool decorate ) const override
+    virtual void print(std::ostream &dst) const override
     {
-        input->print_tensor_as_const(dst, !decorate);
-
-        if (min_tensor) {
-            dst << ", ";
-            min_tensor->print_tensor_as_const(dst, !decorate);
-        }
-
-        if (max_tensor) {
-            dst << ", ";
-            max_tensor->print_tensor_as_const(dst, !decorate);
-        }
-
-        dst << ", ";
-        output->print_tensor(dst, !decorate);
-    }
+        const Tensor *input = inputs[0];
+        const Tensor *min_tensor = nullptr;
+        const Tensor *max_tensor = nullptr;

-    virtual void print(std::ostream &dst) const override
-    {
+        if (inputs.size() > 1 && inputs[1]->is_used())
+            min_tensor = inputs[1];
+        if (inputs.size() > 2 && inputs[2]->is_used())
+            max_tensor = inputs[2];
         INDT_1 << "/* Clip */" << std::endl;

         if( min_tensor )
-            INDT_1 << min_tensor->data_type_str() << " minv = " << min_tensor->cname() << "[0];" << std::endl;
+            INDT_1 << min_tensor->data_type_str() << " minv = min_tensor[0];" << std::endl;
         else
             INDT_1 << "float minv = " << min_attr << ";" << std::endl;

         if( max_tensor )
-            INDT_1 << max_tensor->data_type_str() << " maxv = " << max_tensor->cname() << "[0];" << std::endl;
+            INDT_1 << max_tensor->data_type_str() << " maxv = max_tensor[0];" << std::endl;
         else
             INDT_1 << "float maxv = " << max_attr << ";" << std::endl;
@@ -102,13 +81,12 @@ class Clip : public Node {
             idx += "[" + lv + "]";
         }

-        INDT_2 << output->cname() << idx << " = ";
-        dst << "MAX( MIN( " << input->cname() << idx << ", maxv), minv);" << std::endl;
+        INDT_2 << "output" << idx << " = ";
+        dst << "MAX( MIN( input"<< idx << ", maxv), minv);" << std::endl;

         for( unsigned r=0; r<input->rank(); r++) {
             INDT_1 << "}" << std::endl;
         }
-
     }
 };
 }
diff --git a/src/nodes/concat.h b/src/nodes/concat.h
index 56e0fe9..e8ceb9c 100644
--- a/src/nodes/concat.h
+++ b/src/nodes/concat.h
@@ -11,27 +11,11 @@ namespace toC {
         Concat() {
             op_name = "Concat";
             axis = 1;
-            concat_result = nullptr;
         }

-        // inputs
-        std::vector<const Tensor *> node_inputs; // 'inputs' in the spec
-
-        // output
-        const Tensor *concat_result ;
-
         // attribute
         int axis;

-        void print_parameters(std::ostream &dst, bool decorate) const override {
-            size_t input_count = node_inputs.size();
-            for (size_t i = 0; i < input_count; i++) {
-                node_inputs[i]->print_tensor_as_const(dst, !decorate);
-                dst << ", ";
-            }
-            concat_result->print_tensor(dst, !decorate);
-        }
-
         void parseAttributes(onnx::NodeProto &node) override {
             for (const auto &a : node.attribute()) {
                 if (a.name() == "axis") {
@@ -46,35 +30,39 @@ namespace toC {
         void print(std::ostream &dst) const override {
             dst << "\t/* Concat */" << std::endl;
+            const Tensor *concat_result = outputs[0];

             // the axisPitch is the number of elements to add to move to the next split axis in the concat_result
             int64_t axisPitch = 1;
-            for (int i = concat_result ->data_dim.size() - 1; i >= axis; i--) {
-                axisPitch *= concat_result ->data_dim[i];
+            for (int i = concat_result->data_dim.size() - 1; i >= axis; i--) {
+                axisPitch *= concat_result->data_dim[i];
             }

             dst << "\tint64_t outputOffset;" << std::endl;

             int64_t outputBase = 0;
-            int64_t input_count = node_inputs.size();
+            int64_t input_count = inputs.size();

             for (int64_t inputIndex = 0; inputIndex < input_count; inputIndex++) {
+                std::string input_name = "input_";
+                input_name += std::to_string(inputIndex);
+
                 // the inputAxisPitch is the number of elements to add to move to the next split axis in the inputs
                 int64_t inputAxisPitch = 1;
-                for (int i = node_inputs[inputIndex]->data_dim.size() - 1; i >= axis; i--) {
-                    inputAxisPitch *= node_inputs[inputIndex]->data_dim[i];
+                for (int i = inputs[inputIndex]->data_dim.size() - 1; i >= axis; i--) {
+                    inputAxisPitch *= inputs[inputIndex]->data_dim[i];
                 }

-                int64_t inputSize = node_inputs[inputIndex]->data_num_elem();
+                int64_t inputSize = inputs[inputIndex]->data_num_elem();

                 // copy the data across: for every 'inputAxisPitch' values copied, we move over by the 'axisPitch'
                 dst << "\toutputOffset = " << outputBase << ";" << std::endl;
                 dst << "\tfor (int64_t i = 0, j = 0; i < " << inputSize << "; i++) {" << std::endl;

-                dst << "\t\t*((" << concat_result ->data_type_str() << "*)" << concat_result ->cname();
+                dst << "\t\t*((" << concat_result ->data_type_str() << "*)output";
                 dst << " + (outputOffset + i)) = *((" << concat_result ->data_type_str() << "*)";
-                dst << node_inputs[inputIndex]->cname() << " + i);" << std::endl;
+                dst << input_name << " + i);" << std::endl;

                 dst << "\t\tif (++j == " << inputAxisPitch << ") {" << std::endl;
                 dst << "\t\t\toutputOffset += (" << axisPitch - inputAxisPitch << ");" << std::endl;
@@ -89,7 +77,6 @@ namespace toC {
         }

         void resolve(void) override {
-            node_inputs = inputs;
             if (inputs.size() == 1 ) {
                 LOG(WARNING) << "Concat node " << onnx_name << " has only one input." << std::endl;
             }
@@ -99,7 +86,7 @@ namespace toC {

             auto *rv = new Tensor;
             rv->data_dim = inputs[0]->data_dim;
-            size_t input_count = node_inputs.size();
+            size_t input_count = inputs.size();
             size_t output_axis_size = 0;
             size_t i, j;
             std::vector<int> dims = inputs[0]->data_dim;
@@ -115,12 +102,15 @@ namespace toC {
                 ERROR("Concat's input tensors must have the same shape, except for the "
                       "dimension size of the axis to concatenate on.");
             }
+
+            std::string input_name = "input_";
+            input_name += std::to_string(i);
+            register_input(inputs[i], input_name);
             output_axis_size += inputs[i]->data_dim[axis];
         }
         rv->data_dim[axis] = output_axis_size;

         rv->data_type = inputs[0]->data_type;
-        concat_result = rv;
-        outputs.push_back(rv);
+        register_output(rv, "output");
         }
     };
}
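[annotation] The pitch arithmetic in Concat::print() is easiest to see with concrete numbers. Concatenating input_0 of shape [2][3] and input_1 of shape [2][2] on axis=1 gives an output of shape [2][5]: axisPitch = 5, inputAxisPitch = 3 for input_0 and 2 for input_1. A worked sketch of the loop the node would emit for input_0; shapes and the float element type are made up for illustration:

    int64_t outputOffset = 0;                 /* outputBase for input_0 */
    for (int64_t i = 0, j = 0; i < 6; i++) {  /* 6 = elements in input_0 */
        *((float*)output + (outputOffset + i)) = *((float*)input_0 + i);
        if (++j == 3) {               /* finished one run along axis 1 */
            outputOffset += (5 - 3);  /* skip over input_1's columns  */
            j = 0;
        }
    }

input_0's rows land at output offsets 0..2 and 5..7; the second copy loop then starts from outputBase 3 and places input_1 at offsets 3..4 and 8..9.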
diff --git a/src/nodes/constant.h b/src/nodes/constant.h
index 7c66834..364d944 100644
--- a/src/nodes/constant.h
+++ b/src/nodes/constant.h
@@ -10,22 +10,15 @@ class Constant : public Node {
     public:
     Constant() {
         op_name = "Constant";
-        output = NULL;
     }

-    const Tensor *output;
-
-
-    virtual void print_parameters(std::ostream &dst, bool decorate ) const override
-    {
-        output->print_tensor(dst, !decorate);
-    }
+    Tensor *value_tensor = nullptr;

     virtual void parseAttributes( onnx::NodeProto &node ) override
     {
         for( const auto& a : node.attribute() ) {
             LOG(TRACE) << "Parsing attribute " << a.name() << std::endl;
             if( a.name() == "value" )
-                output = parse_attribute_tensor(a);
+                value_tensor = parse_attribute_tensor(a);
             else
                 ERROR("Unimplemented parsing of attribute " << a.name());
         }
@@ -36,20 +29,16 @@ class Constant : public Node {
     {
         dst << "\t/* Constant */" << std::endl;
         dst << "\t/* The output is generated as a global tensor */" << std::endl;
-        dst << "\t(void)" << output->cname() << ";" << std::endl;
+        dst << "\t(void)output;" << std::endl;
     }

     virtual void resolve(void) override
     {
-        if( output == NULL )
+        if( value_tensor == nullptr )
             ERROR("Constant tensor not resolved");
         // "This operator produces a constant tensor."
-        Tensor *t = const_cast<Tensor *>(output);
-        t->isConst = true;
-        outputs.push_back(t);
+        value_tensor->isConst = true;
+        register_output(value_tensor, "output");
     }
 };
diff --git a/src/nodes/constantofshape.cc b/src/nodes/constantofshape.cc
index c08fa9b..28165e5 100644
--- a/src/nodes/constantofshape.cc
+++ b/src/nodes/constantofshape.cc
@@ -19,7 +19,7 @@ void ConstantOfShape::parseAttributes( onnx::NodeProto &node )

 void ConstantOfShape::resolve(void)
 {
-    input = inputs[0];
+    Tensor *input = inputs[0];
     register_input(input, "input");

     Tensor *t = new Tensor;
@@ -32,17 +32,17 @@ void ConstantOfShape::resolve(void)
         t->data_type = value->data_type;
     else
         t->data_type = onnx::TensorProto_DataType_FLOAT;
-    output = t;
     register_output(t, "output");
 }

-/* Body of the node implementing function */
 void ConstantOfShape::print(std::ostream &dst) const
 {
-    INDT_1 << "/* ConstantOfShape */" << std::endl;
+    Tensor *output = outputs[0];
     std::string type = output->data_type_str();

+    INDT_1 << "/* ConstantOfShape */" << std::endl;
+
     INDT_1 << type << " *dst = (" << type << "*)output;" << std::endl;

     INDT_1 << "for( unsigned i=0; i< " << output->data_num_elem() << "; i++)" << std::endl;
     INDT_2 << "dst[i] = " ;
diff --git a/src/nodes/constantofshape.h b/src/nodes/constantofshape.h
index f4f7142..de38e63 100644
--- a/src/nodes/constantofshape.h
+++ b/src/nodes/constantofshape.h
@@ -22,12 +22,9 @@ class ConstantOfShape : public Node {
     public:
     ConstantOfShape() {
         op_name = "ConstantOfShape";
-        input=output=value=NULL;
+        value=NULL;
     }

-    const Tensor *input;
-    const Tensor *output;
-
     // Attribute, not input
     const Tensor *value;
diff --git a/src/nodes/conv.h b/src/nodes/conv.h
index 3b1ba49..c4b2794 100644
--- a/src/nodes/conv.h
+++ b/src/nodes/conv.h
@@ -11,11 +11,7 @@ class Conv : public SpatialFilter {
     public:
     Conv() {
         op_name = "Conv";
-        b=NULL;
     }
-    /* Conv node specific attributes */
-    // optional inputs
-    const Tensor *b;

     virtual void print_output_cell_init(std::ostream &dst, const std::string &y_idx) const override
     {
@@ -23,7 +19,7 @@ class Conv : public SpatialFilter {
         for(unsigned i=0; i<x->rank()-2; i++)
             outidx += "[o" + std::to_string(i) + "]";
         INDT_3 << "y[b][m]" << outidx << " = ";
-        if( b == NULL )
+        if( inputs.size() < 3 ) // bias is the 3rd input, optional
             dst << "0;" << std::endl;
         else
             dst << "bias[m];" << std::endl;
@@ -64,17 +60,12 @@ class Conv : public SpatialFilter {
         w = inputs[1]; // weights
         register_input(w,"w");
         if( inputs.size() == 3 ) {
-            b = inputs[2];
-            register_input(b,"bias");
+            register_input(inputs[2],"bias");
         }
-        else
-            b = NULL;

         if( typeConstraint_highPrecisionNumeric(x) == false
            ||typeConstraint_highPrecisionNumeric(w) == false)
             ERROR("Incorrect input for node");
-        if( b && (typeConstraint_highPrecisionNumeric(b) == false) )
-            ERROR("Incorrect input for node");

         resolve_strides();
         resolve_dilations();
", "; - } - if( w_zero_point ) { - w_zero_point->print_tensor_as_const(dst, !decorate); - dst << ", "; - } - y->print_tensor(dst, !decorate); + x=w=y=NULL; } virtual void print_output_cell_init(std::ostream &dst, const std::string &y_idx) const override @@ -48,7 +26,7 @@ class ConvInteger : public SpatialFilter { if( options.quantize ) INDT_3 << "int32_t cell = 0;" << std::endl; else - INDT_3 << y->cname() << "[b][m][o0][o1] = 0;" << std::endl; + INDT_3 << "y[b][m][o0][o1] = 0;" << std::endl; } virtual void print_output_cell_calc( @@ -58,19 +36,19 @@ class ConvInteger : public SpatialFilter { const std::string &y_idx) const override { std::string x_zero; - if( x_zero_point ) - x_zero = constant_acces_code( x_zero_point->cname() + "[0]"); + if( inputs.size() >= 3 ) // x_zero_point is optional, 3rd input + x_zero = constant_acces_code( "x_zero_point[0]"); else x_zero = "0"; - INDT_4 << w->data_type_str() << " w = " << constant_acces_code( w->cname() + "[m][c][k0][k1]") << ";" << std::endl; + INDT_4 << w->data_type_str() << " w_ = " << constant_acces_code("w[m][c][k0][k1]") << ";" << std::endl; std::string dest; if( options.quantize ) dest = "cell"; else - dest = y->cname() + "[b][m][o0][o1]"; + dest = "y[b][m][o0][o1]"; - INDT_4 << dest << "+= ("<< x->cname() << "[b][c][i0+k0][i1+k1] - " << x_zero << ") * w;" << std::endl; + INDT_4 << dest << "+= (x[b][c][i0+k0][i1+k1] - " << x_zero << ") * w_;" << std::endl; } virtual void print_output_cell_finalize(std::ostream &dst, const std::string &y_idx) const override @@ -81,7 +59,7 @@ class ConvInteger : public SpatialFilter { INDT_3 << "int32_t tmp = cell/" << divisor << ";" << std::endl; INDT_3 << "tmp = tmp > 127?127:tmp;" << std::endl; INDT_3 << "tmp = tmp < -127?-127:tmp;" << std::endl; - INDT_3 << y->cname() + "[b][m][o0][o1] = tmp;" << std::endl; + INDT_3 << "y[b][m][o0][o1] = tmp;" << std::endl; } } @@ -95,12 +73,14 @@ class ConvInteger : public SpatialFilter { virtual void resolve(void) override { x = inputs[0]; // data + register_input(x, "x"); w = inputs[1]; // weights + register_input(w, "w"); if( inputs.size() > 2 ) - x_zero_point = inputs[2]; + register_input(inputs[2], "x_zero_point"); if( inputs.size() > 3 ){ - w_zero_point = inputs[3]; + register_input(inputs[3], "w_zero_point"); ERROR("unimplemented: weight zero points"); } @@ -128,7 +108,7 @@ class ConvInteger : public SpatialFilter { else rv->data_type = onnx::TensorProto_DataType_INT32; y=rv; - outputs.push_back(rv); + register_output(rv, "y"); } }; } diff --git a/src/nodes/elementwise_2.h b/src/nodes/elementwise_2.h index fa7968c..e87e383 100644 --- a/src/nodes/elementwise_2.h +++ b/src/nodes/elementwise_2.h @@ -12,10 +12,6 @@ class Elementwise_2 : public Node { std::function operation = [](const std::string &a, const std::string &b){ ERROR("onnx2c internal error"); return ""; }; - // input and output: C = A ? 
diff --git a/src/nodes/elementwise_2.h b/src/nodes/elementwise_2.h
index fa7968c..e87e383 100644
--- a/src/nodes/elementwise_2.h
+++ b/src/nodes/elementwise_2.h
@@ -12,10 +12,6 @@ class Elementwise_2 : public Node {
     std::function<std::string(const std::string &a, const std::string &b)> operation =
         [](const std::string &a, const std::string &b){ ERROR("onnx2c internal error"); return ""; };

-    // input and output: C = A ? B
-    const Tensor *A;
-    const Tensor *B;
-    const Tensor *C;
     bool output_is_bool;

     // Union of attributes over implemented nodes
@@ -24,7 +20,6 @@ class Elementwise_2 : public Node {

     Elementwise_2(std::string op) {
         op_name = op;
-        A=B=C=NULL;
         output_is_bool = false;
         fmod=0;
         shift_dir="NOT_GIVEN"; // mandatory for BitShift, but no default
@@ -109,13 +104,13 @@ class Elementwise_2 : public Node {

     virtual void print_parameters(std::ostream &dst, bool decorate ) const override
     {
-        A->print_tensor_as_const(dst, !decorate);
+        inputs[0]->print_tensor_as_const(dst, !decorate);

         dst << ", ";

-        B->print_tensor_as_const(dst, !decorate);
+        inputs[1]->print_tensor_as_const(dst, !decorate);

         dst << ", ";

-        C->print_tensor(dst, !decorate);
+        outputs[0]->print_tensor(dst, !decorate);
     }

     virtual void parseAttributes( onnx::NodeProto &node ) override {
@@ -141,6 +136,11 @@ class Elementwise_2 : public Node {
         INDT_1 << " fmod: " << fmod << std::endl;
         INDT_1 << " */" << std::endl;

+        // C = A ? B
+        Tensor *A = inputs[0];
+        Tensor *B = inputs[1];
+        Tensor *C = outputs[0];
+
         // if either A or B does not have enough dimensions, prepend
         // dimensions of 1 to match rank of C
         std::vector<int> padA = A->data_dim;
@@ -190,8 +190,8 @@ class Elementwise_2 : public Node {

     virtual void resolve(void) override
     {
-        A = inputs[0];
-        B = inputs[1];
+        Tensor *A = inputs[0];
+        Tensor *B = inputs[1];

         std::vector<int> result_dim;
         multidirectional_broadcast_size(A->data_dim, B->data_dim, result_dim);
@@ -202,7 +202,6 @@
             t->data_type = onnx::TensorProto_DataType_BOOL;
         else
             t->data_type = A->data_type;
-        C = t;
         outputs.push_back(t);
     }
 };
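[annotation] For Elementwise_2, resolve() computes the output shape with multidirectional_broadcast_size(), and print() left-pads the shorter input shape with size-1 dimensions, indexing those axes with 0. A worked example with hypothetical shapes, using Add as the operation: A is [2][3][4] and B is [3][1], so B is treated as [1][3][1] and C gets shape [2][3][4]. The generated loop nest is then roughly:

    for (int i0 = 0; i0 < 2; i0++)
    for (int i1 = 0; i1 < 3; i1++)
    for (int i2 = 0; i2 < 4; i2++)
        C[i0][i1][i2] = A[i0][i1][i2] + B[i1][0]; /* B's size-1 axes pinned to 0 */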