diff --git a/src/nodes/averagepool.h b/src/nodes/averagepool.h index 1818cc8..a650d59 100644 --- a/src/nodes/averagepool.h +++ b/src/nodes/averagepool.h @@ -17,7 +17,7 @@ class AveragePool : public Pooling { virtual void print_output_cell_init(std::ostream &dst, const std::string &y_idx) const override { - INDT_3 << y->data_type_str() << " curavg = 0.0;" << std::endl; + INDT_3 << get_Y()->data_type_str() << " curavg = 0.0;" << std::endl; INDT_3 << "int numavg = 0;" << std::endl; } virtual void print_output_cell_calc( @@ -52,12 +52,7 @@ class AveragePool : public Pooling { virtual void resolve(void) override { - x = inputs[0]; - register_input(x, "x"); - - if( !( typeConstraint_plainFloatingPoints(x) - ||typeConstraint_8bit(x)) ) - ERROR("Incorrect input for node"); + register_input(inputs[0], "x"); resolve_strides(); resolve_dilations(); @@ -66,8 +61,7 @@ class AveragePool : public Pooling { Tensor *rv = new Tensor; rv->data_dim = resolve_output_size(); - rv->data_type = x->data_type; - y=rv; + rv->data_type = get_X()->data_type; register_output(rv, "y"); update_pads(); diff --git a/src/nodes/conv.h b/src/nodes/conv.h index c4b2794..9bd641f 100644 --- a/src/nodes/conv.h +++ b/src/nodes/conv.h @@ -16,7 +16,7 @@ class Conv : public SpatialFilter { virtual void print_output_cell_init(std::ostream &dst, const std::string &y_idx) const override { std::string outidx=""; - for(unsigned i=0; irank()-2; i++) + for(unsigned i=0; irank()-2; i++){ + for(unsigned i=0; idata_dim = resolve_output_size(); - rv->data_type = x->data_type; - y=rv; + rv->data_type = get_X()->data_type; register_output(rv, "y"); } }; diff --git a/src/nodes/convinteger.h b/src/nodes/convinteger.h index c29b3f5..777841d 100644 --- a/src/nodes/convinteger.h +++ b/src/nodes/convinteger.h @@ -18,7 +18,6 @@ class ConvInteger : public SpatialFilter { op_name = "ConvInteger"; auto_pad = "NOTSET"; group = 1; - x=w=y=NULL; } virtual void print_output_cell_init(std::ostream &dst, const std::string &y_idx) const override @@ -41,7 +40,7 @@ class ConvInteger : public SpatialFilter { else x_zero = "0"; - INDT_4 << w->data_type_str() << " w_ = " << constant_acces_code("w[m][c][k0][k1]") << ";" << std::endl; + INDT_4 << get_W()->data_type_str() << " w_ = " << constant_acces_code("w[m][c][k0][k1]") << ";" << std::endl; std::string dest; if( options.quantize ) dest = "cell"; @@ -72,10 +71,8 @@ class ConvInteger : public SpatialFilter { virtual void resolve(void) override { - x = inputs[0]; // data - register_input(x, "x"); - w = inputs[1]; // weights - register_input(w, "w"); + register_input(inputs[0], "x"); + register_input(inputs[1], "w"); if( inputs.size() > 2 ) register_input(inputs[2], "x_zero_point"); @@ -84,7 +81,7 @@ class ConvInteger : public SpatialFilter { ERROR("unimplemented: weight zero points"); } - if( x->data_dim.size() != 4 ) + if( get_X()->data_dim.size() != 4 ) ERROR("Unimplemented: ConvInteger for non 2D images"); @@ -107,7 +104,6 @@ class ConvInteger : public SpatialFilter { rv->data_type = onnx::TensorProto_DataType_INT8; else rv->data_type = onnx::TensorProto_DataType_INT32; - y=rv; register_output(rv, "y"); } }; diff --git a/src/nodes/maxpool.h b/src/nodes/maxpool.h index a1975a5..12abb24 100644 --- a/src/nodes/maxpool.h +++ b/src/nodes/maxpool.h @@ -12,25 +12,10 @@ class MaxPool : public Pooling { public: MaxPool() { op_name = "MaxPool"; - Indices=NULL; } - // optional outputs - const Tensor *Indices; - std::vector pad_shapes; - virtual void print_parameters(std::ostream &dst, bool decorate ) const override - { - x->print_tensor_as_const(dst, !decorate); - dst << ", "; - y->print_tensor(dst, !decorate); - if( Indices->name != "" ) { - dst << ", "; - Indices->print_tensor(dst, !decorate); - } - } - virtual void parseAttributes( onnx::NodeProto &node ) override { Pooling::parseAttributes(node); @@ -43,7 +28,7 @@ class MaxPool : public Pooling { virtual void print_output_cell_init(std::ostream &dst, const std::string &y_idx) const override { - std::string type = x->data_type_str(); + std::string type = get_X()->data_type_str(); std::string type_min_value; if( type == "float" ) type_min_value = "-FLT_MAX"; @@ -57,7 +42,7 @@ class MaxPool : public Pooling { ERROR("Unimplemented: minimum value for this type"); INDT_3 << type << " curmax = " << type_min_value << ";" << std::endl; - if( Indices->name != "" ) + if( get_Indices() ) INDT_3 << "int64_t curmaxind = -1;" << std::endl; @@ -68,7 +53,9 @@ class MaxPool : public Pooling { const std::string &w_idx, const std::string &y_idx) const override { - unsigned n_data_dims = x->data_dim.size()-2; + unsigned n_data_dims = get_numDataDim(); + const Tensor *x = get_X(); + // Calculate how much one index means in terms of the Indices output. // Generate helper string for the next step. std::vectorsize_of_dim(x->rank()); @@ -80,9 +67,9 @@ class MaxPool : public Pooling { indices_value += "+(ii" + std::to_string(i) + "*" + std::to_string(size_of_dim[i+2]) + ")"; // Update the max and index value - INDT_4 << "if( curmax < " << x->cname() << x_idx << ") {" <cname() << x_idx << ");" <name != "" ) + INDT_4 << "if( curmax < x" << x_idx << ") {" <cname() << y_idx << "= curmax;" << std::endl; - if( Indices->name != "" ) - INDT_3 << Indices->cname() << y_idx << "= curmaxind;" << std::endl; + INDT_3 << "y" << y_idx << "= curmax;" << std::endl; + if( get_Indices() ) + INDT_3 << "Indices " << y_idx << "= curmaxind;" << std::endl; } @@ -105,7 +92,7 @@ class MaxPool : public Pooling { virtual void resolve(void) override { - x = inputs[0]; + register_input(inputs[0], "x"); resolve_strides(); resolve_dilations(); @@ -118,9 +105,8 @@ class MaxPool : public Pooling { Tensor *rv = new Tensor; rv->data_dim = resolve_output_size(); - rv->data_type = x->data_type; - y=rv; - outputs.push_back(rv); + rv->data_type = get_X()->data_type; + register_output(rv, "y"); update_pads(); @@ -128,10 +114,13 @@ class MaxPool : public Pooling { Tensor *indices_out = new Tensor; indices_out->data_type = onnx::TensorProto_DataType::TensorProto_DataType_INT64; indices_out->data_dim = rv->data_dim; - Indices = indices_out; - outputs.push_back(indices_out); - - + register_output(indices_out, "Indices"); + } + const Tensor* get_Indices(void) const { + if( outputs[1]->name != "" ) + return outputs[1]; + else + return nullptr; } }; } diff --git a/src/nodes/pooling.h b/src/nodes/pooling.h index 9f58c79..8bd83f2 100644 --- a/src/nodes/pooling.h +++ b/src/nodes/pooling.h @@ -46,19 +46,19 @@ class Pooling : public SpatialFilter { virtual std::vector resolve_output_size(void) override { std::vector rv; - rv.push_back(x->data_dim[0]);//batch - rv.push_back(x->data_dim[1]);//channel + rv.push_back(get_X()->data_dim[0]);//batch + rv.push_back(get_X()->data_dim[1]);//channel - unsigned data_dims = x->data_dim.size()-2; + unsigned data_dims = get_numDataDim(); std::vector pad_shapes; for( unsigned i=0; idata_dim.size(); i++ ) { + for( unsigned i=2; idata_dim.size(); i++ ) { int d; - int in_dim = x->data_dim[i]; + int in_dim = get_X()->data_dim[i]; int kernel = kernel_shape[i-2]; int dilation = dilations.size()==0 ? 1 : dilations[i-2]; int stride = strides[i-2]; @@ -92,7 +92,7 @@ class Pooling : public SpatialFilter { if ( auto_pad == "VALID" ) return; - unsigned data_dims = x->data_dim.size()-2; + unsigned data_dims = get_numDataDim(); // Calculate pads for the "SAME_*" cases that need the output shape for(unsigned i=0; idata_dim[i+2]; - int output_size = y->data_dim[i+2]; + int input_size = get_X()->data_dim[i+2]; + int output_size = get_Y()->data_dim[i+2]; int pad_shape = (output_size - 1) * strides[i] + (( kernel_shape[i] -1) * dilations[i]+1) - input_size; pads[i] = pad_shape/2; pads[i+data_dims] = pad_shape/2; diff --git a/src/nodes/spatialfilter.h b/src/nodes/spatialfilter.h index 5d420fd..96bd35a 100644 --- a/src/nodes/spatialfilter.h +++ b/src/nodes/spatialfilter.h @@ -17,14 +17,8 @@ class SpatialFilter : public Node { public: SpatialFilter() { auto_pad = "NOTSET"; - x=w=y=NULL; group=1; } - // inputs - const Tensor *x; - const Tensor *w; - // outputs - const Tensor *y; // Attributes std::vector kernel_shape; @@ -34,6 +28,16 @@ class SpatialFilter : public Node { std::vector pads; std::vector strides; + const Tensor* get_X(void) const { return inputs[0]; } + const Tensor* get_W(void) const { + if( inputs.size() > 1 ) + return inputs[1]; + else + return nullptr; + } + const Tensor* get_Y(void) const { return outputs[0]; } + uint32_t get_numDataDim(void) const {return get_X()->rank() - 2; } + virtual void parseAttributes( onnx::NodeProto &node ) override { for( const auto& a : node.attribute() ) { if( a.name() == "auto_pad" ) @@ -54,9 +58,8 @@ class SpatialFilter : public Node { void resolve_strides(void) { - unsigned num_data_dim = x->rank()-2; if( strides.size() == 0 ) - for( unsigned i=0; irank(); i++) - kernel_shape.push_back(w->data_dim[i]); + for( unsigned i=2; irank(); i++) + kernel_shape.push_back(get_W()->data_dim[i]); } } void resolve_dilations(void) { - unsigned num_data_dim = x->rank()-2; if( dilations.size() == 0 ) - for( unsigned i=0; i< num_data_dim; i++ ) + for( unsigned i=0; i< get_numDataDim(); i++ ) dilations.push_back(1); } void resolve_pads(void) { - unsigned num_data_dim = x->rank()-2; + unsigned num_data_dim = get_numDataDim(); if( pads.size() == 0 ) { pads.resize(num_data_dim*2); for( unsigned i=0; i< num_data_dim; i++ ) { @@ -100,13 +102,14 @@ class SpatialFilter : public Node { virtual std::vector resolve_output_size(void) { std::vector rv; - unsigned num_data_dim = x->rank()-2; - rv.push_back(x->data_dim[0]);//batch size - rv.push_back(w->data_dim[0]);//"number of feature maps" + unsigned num_data_dim = get_numDataDim(); + rv.push_back(get_X()->data_dim[0]);//batch size + rv.push_back(get_W()->data_dim[0]);//"number of feature maps" - for( unsigned xdim=2; xdim < x->data_dim.size(); xdim++) { + for( unsigned dim=0, xdim=2; + dim < num_data_dim; + dim++, xdim++) { int outdim; - unsigned dim = xdim-2; // Not sure if the naming is correct. Here // kernel: the (number of) weights of the filter // filter: the spatial placement of the kernel weights @@ -118,10 +121,10 @@ class SpatialFilter : public Node { // SAME_UPPER or SAME_LOWER mean pad the input so that the output spatial size match the input. // "match" here means "is equal". if( auto_pad == "SAME_UPPER" || auto_pad == "SAME_LOWER" ) - outdim = x->data_dim[xdim]; + outdim = get_X()->data_dim[xdim]; else if( auto_pad == "NOTSET" || auto_pad == "VALID") { //padded input - int input_size = x->data_dim[xdim] + pads[dim]+pads[dim+num_data_dim]; + int input_size = get_X()->data_dim[xdim] + pads[dim]+pads[dim+num_data_dim]; // [ 0 1 2 3 4 5 6 7 8 9 ] // |kern=3| // last output=7 @@ -185,10 +188,10 @@ class SpatialFilter : public Node { virtual void print_output_cell_finalize(std::ostream &dst, const std::string &y_idx="") const = 0; void print_loop_with_padding_checks(std::ostream &dst) const { - unsigned n_data_dims = x->data_dim.size() -2; - unsigned batch_size = x->data_dim[0]; - unsigned channels = x->data_dim[1]; - unsigned maps=y->data_dim[1]; + unsigned n_data_dims = get_numDataDim(); + unsigned batch_size = get_X()->data_dim[0]; + unsigned channels = get_X()->data_dim[1]; + unsigned maps=get_Y()->data_dim[1]; /* Create various indexing strings. This makes generating the loops much cleaner, * and makes possible the code sharing in child classes. */ @@ -213,7 +216,7 @@ class SpatialFilter : public Node { } if( direct_channel_map() ) INDT_1 << "for( uint32_t m=0, c=0; m<" << maps << "; m++, c=m) {" << std::endl; - else if( w && group > 1 ) { + else if( get_W() && group > 1 ) { INDT_1 << "uint32_t go = " << maps/group << "; // output group size, i.e. maps/group" << std::endl; INDT_1 << "uint32_t gi = " << channels/group << "; // inptput group size, i.e. channels/group" << std::endl; INDT_1 << "for( uint32_t g=0; g<" << group << "; g++) {" << std::endl; @@ -229,7 +232,7 @@ class SpatialFilter : public Node { std::string i_idx = "i" + std::to_string(i); INDT_2 << "for( int32_t " << o_idx << "=0, "; dst << i_idx << "=" << -pads[i] << "; "; - dst << o_idx << "<" << y->data_dim[2+i] << "; "; + dst << o_idx << "<" << get_Y()->data_dim[2+i] << "; "; dst << o_idx <<"++, "<< i_idx << "+=" << strides[i] << ") {" << std::endl; } @@ -237,7 +240,7 @@ class SpatialFilter : public Node { if (direct_channel_map()) ; - else if( w && group > 1 ) + else if( get_W() && group > 1 ) INDT_3 << "for( int32_t c=gi*g; c=" << x->data_dim[2+i] << ") continue;" << std::endl; + INDT_4 << "if( ii" << i_str << ">=" << get_X()->data_dim[2+i] << ") continue;" << std::endl; } print_output_cell_calc(dst, in_kern_idxs, "", y_idx);