Skip to content

Commit

Permalink
Remove local tensor copies: spatialfilter+kids
Browse files Browse the repository at this point in the history
- averagepool
- conv
- convinteger
- maxpool
  • Loading branch information
kraiskil committed Jul 26, 2023
1 parent a167f31 commit 14f2295
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 97 deletions.
12 changes: 3 additions & 9 deletions src/nodes/averagepool.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class AveragePool : public Pooling {

virtual void print_output_cell_init(std::ostream &dst, const std::string &y_idx) const override
{
INDT_3 << y->data_type_str() << " curavg = 0.0;" << std::endl;
INDT_3 << get_Y()->data_type_str() << " curavg = 0.0;" << std::endl;
INDT_3 << "int numavg = 0;" << std::endl;
}
virtual void print_output_cell_calc(
Expand Down Expand Up @@ -52,12 +52,7 @@ class AveragePool : public Pooling {

virtual void resolve(void) override
{
x = inputs[0];
register_input(x, "x");

if( !( typeConstraint_plainFloatingPoints(x)
||typeConstraint_8bit(x)) )
ERROR("Incorrect input for node");
register_input(inputs[0], "x");

resolve_strides();
resolve_dilations();
Expand All @@ -66,8 +61,7 @@ class AveragePool : public Pooling {

Tensor *rv = new Tensor;
rv->data_dim = resolve_output_size();
rv->data_type = x->data_type;
y=rv;
rv->data_type = get_X()->data_type;
register_output(rv, "y");

update_pads();
Expand Down
17 changes: 5 additions & 12 deletions src/nodes/conv.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class Conv : public SpatialFilter {
virtual void print_output_cell_init(std::ostream &dst, const std::string &y_idx) const override
{
std::string outidx="";
for(unsigned i=0; i<x->rank()-2; i++)
for(unsigned i=0; i<get_numDataDim(); i++)
outidx += "[o" + std::to_string(i) + "]";
INDT_3 << "y[b][m]" << outidx << " = ";
if( inputs.size() < 3 ) // bias is the 3rd input, optional
Expand All @@ -33,7 +33,7 @@ class Conv : public SpatialFilter {
std::string outidx="";
std::string iididx="";
std::string kidx="";
for(unsigned i=0; i<x->rank()-2; i++){
for(unsigned i=0; i<get_numDataDim(); i++){
outidx += "[o" + std::to_string(i) + "]";
iididx+= "[ii" + std::to_string(i) + "]";
kidx+= "[k" + std::to_string(i) + "]";
Expand All @@ -55,27 +55,20 @@ class Conv : public SpatialFilter {

virtual void resolve(void) override
{
x = inputs[0]; // data
register_input(x,"x");
w = inputs[1]; // weights
register_input(w,"w");
register_input(inputs[0],"x");
register_input(inputs[1],"w");
if( inputs.size() == 3 ) {
register_input(inputs[2],"bias");
}

if( typeConstraint_highPrecisionNumeric(x) == false
||typeConstraint_highPrecisionNumeric(w) == false)
ERROR("Incorrect input for node");

resolve_strides();
resolve_dilations();
resolve_pads();
resolve_kernel_shape();

Tensor *rv = new Tensor;
rv->data_dim = resolve_output_size();
rv->data_type = x->data_type;
y=rv;
rv->data_type = get_X()->data_type;
register_output(rv, "y");
}
};
Expand Down
12 changes: 4 additions & 8 deletions src/nodes/convinteger.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ class ConvInteger : public SpatialFilter {
op_name = "ConvInteger";
auto_pad = "NOTSET";
group = 1;
x=w=y=NULL;
}

virtual void print_output_cell_init(std::ostream &dst, const std::string &y_idx) const override
Expand All @@ -41,7 +40,7 @@ class ConvInteger : public SpatialFilter {
else
x_zero = "0";

INDT_4 << w->data_type_str() << " w_ = " << constant_acces_code("w[m][c][k0][k1]") << ";" << std::endl;
INDT_4 << get_W()->data_type_str() << " w_ = " << constant_acces_code("w[m][c][k0][k1]") << ";" << std::endl;
std::string dest;
if( options.quantize )
dest = "cell";
Expand Down Expand Up @@ -72,10 +71,8 @@ class ConvInteger : public SpatialFilter {

virtual void resolve(void) override
{
x = inputs[0]; // data
register_input(x, "x");
w = inputs[1]; // weights
register_input(w, "w");
register_input(inputs[0], "x");
register_input(inputs[1], "w");

if( inputs.size() > 2 )
register_input(inputs[2], "x_zero_point");
Expand All @@ -84,7 +81,7 @@ class ConvInteger : public SpatialFilter {
ERROR("unimplemented: weight zero points");
}

if( x->data_dim.size() != 4 )
if( get_X()->data_dim.size() != 4 )
ERROR("Unimplemented: ConvInteger for non 2D images");


Expand All @@ -107,7 +104,6 @@ class ConvInteger : public SpatialFilter {
rv->data_type = onnx::TensorProto_DataType_INT8;
else
rv->data_type = onnx::TensorProto_DataType_INT32;
y=rv;
register_output(rv, "y");
}
};
Expand Down
53 changes: 21 additions & 32 deletions src/nodes/maxpool.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,10 @@ class MaxPool : public Pooling {
public:
MaxPool() {
op_name = "MaxPool";
Indices=NULL;
}

// optional outputs
const Tensor *Indices;

std::vector<int> pad_shapes;

virtual void print_parameters(std::ostream &dst, bool decorate ) const override
{
x->print_tensor_as_const(dst, !decorate);
dst << ", ";
y->print_tensor(dst, !decorate);
if( Indices->name != "" ) {
dst << ", ";
Indices->print_tensor(dst, !decorate);
}
}

virtual void parseAttributes( onnx::NodeProto &node ) override {

Pooling::parseAttributes(node);
Expand All @@ -43,7 +28,7 @@ class MaxPool : public Pooling {

virtual void print_output_cell_init(std::ostream &dst, const std::string &y_idx) const override
{
std::string type = x->data_type_str();
std::string type = get_X()->data_type_str();
std::string type_min_value;
if( type == "float" )
type_min_value = "-FLT_MAX";
Expand All @@ -57,7 +42,7 @@ class MaxPool : public Pooling {
ERROR("Unimplemented: minimum value for this type");

INDT_3 << type << " curmax = " << type_min_value << ";" << std::endl;
if( Indices->name != "" )
if( get_Indices() )
INDT_3 << "int64_t curmaxind = -1;" << std::endl;


Expand All @@ -68,7 +53,9 @@ class MaxPool : public Pooling {
const std::string &w_idx,
const std::string &y_idx) const override
{
unsigned n_data_dims = x->data_dim.size()-2;
unsigned n_data_dims = get_numDataDim();
const Tensor *x = get_X();

// Calculate how much one index means in terms of the Indices output.
// Generate helper string for the next step.
std::vector<int>size_of_dim(x->rank());
Expand All @@ -80,9 +67,9 @@ class MaxPool : public Pooling {
indices_value += "+(ii" + std::to_string(i) + "*" + std::to_string(size_of_dim[i+2]) + ")";

// Update the max and index value
INDT_4 << "if( curmax < " << x->cname() << x_idx << ") {" <<std::endl;
INDT_4 << "curmax = MAX( curmax, " << x->cname() << x_idx << ");" <<std::endl;
if( Indices->name != "" )
INDT_4 << "if( curmax < x" << x_idx << ") {" <<std::endl;
INDT_4 << "curmax = MAX( curmax, x" << x_idx << ");" <<std::endl;
if( get_Indices() )
INDT_4 << "curmaxind = " << indices_value << ";" <<std::endl;
INDT_4 << "}" << std::endl;

Expand All @@ -91,9 +78,9 @@ class MaxPool : public Pooling {
virtual void print_output_cell_finalize(std::ostream &dst, const std::string &y_idx) const override
{
// Store the calculated values into output tensors
INDT_3 << y->cname() << y_idx << "= curmax;" << std::endl;
if( Indices->name != "" )
INDT_3 << Indices->cname() << y_idx << "= curmaxind;" << std::endl;
INDT_3 << "y" << y_idx << "= curmax;" << std::endl;
if( get_Indices() )
INDT_3 << "Indices " << y_idx << "= curmaxind;" << std::endl;
}


Expand All @@ -105,7 +92,7 @@ class MaxPool : public Pooling {

virtual void resolve(void) override
{
x = inputs[0];
register_input(inputs[0], "x");

resolve_strides();
resolve_dilations();
Expand All @@ -118,20 +105,22 @@ class MaxPool : public Pooling {
Tensor *rv = new Tensor;
rv->data_dim = resolve_output_size();

rv->data_type = x->data_type;
y=rv;
outputs.push_back(rv);
rv->data_type = get_X()->data_type;
register_output(rv, "y");

update_pads();

// optional indices vector
Tensor *indices_out = new Tensor;
indices_out->data_type = onnx::TensorProto_DataType::TensorProto_DataType_INT64;
indices_out->data_dim = rv->data_dim;
Indices = indices_out;
outputs.push_back(indices_out);


register_output(indices_out, "Indices");
}
const Tensor* get_Indices(void) const {
if( outputs[1]->name != "" )
return outputs[1];
else
return nullptr;
}
};
}
16 changes: 8 additions & 8 deletions src/nodes/pooling.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,19 +46,19 @@ class Pooling : public SpatialFilter {
virtual std::vector<int> resolve_output_size(void) override
{
std::vector<int> rv;
rv.push_back(x->data_dim[0]);//batch
rv.push_back(x->data_dim[1]);//channel
rv.push_back(get_X()->data_dim[0]);//batch
rv.push_back(get_X()->data_dim[1]);//channel

unsigned data_dims = x->data_dim.size()-2;
unsigned data_dims = get_numDataDim();
std::vector<int> pad_shapes;
for( unsigned i=0; i<data_dims; i++ ) {
pad_shapes.push_back(pads[i]+pads[data_dims+i]);
}
// Calculate output shape. Pads are now calculated
// for those auto_pad modes that need them.
for( unsigned i=2; i<x->data_dim.size(); i++ ) {
for( unsigned i=2; i<get_X()->data_dim.size(); i++ ) {
int d;
int in_dim = x->data_dim[i];
int in_dim = get_X()->data_dim[i];
int kernel = kernel_shape[i-2];
int dilation = dilations.size()==0 ? 1 : dilations[i-2];
int stride = strides[i-2];
Expand Down Expand Up @@ -92,7 +92,7 @@ class Pooling : public SpatialFilter {
if ( auto_pad == "VALID" )
return;

unsigned data_dims = x->data_dim.size()-2;
unsigned data_dims = get_numDataDim();

// Calculate pads for the "SAME_*" cases that need the output shape
for(unsigned i=0; i<data_dims; i++) {
Expand All @@ -101,8 +101,8 @@ class Pooling : public SpatialFilter {
// The auto_pad attribute for AveragePool is deprecated anyway. Probably just for this confusion.
// This tries to be some sort of band-aid: assume the output size is the same as input size
// which is the usual(?) reason to use "same" padding on the network design level.
int input_size = x->data_dim[i+2];
int output_size = y->data_dim[i+2];
int input_size = get_X()->data_dim[i+2];
int output_size = get_Y()->data_dim[i+2];
int pad_shape = (output_size - 1) * strides[i] + (( kernel_shape[i] -1) * dilations[i]+1) - input_size;
pads[i] = pad_shape/2;
pads[i+data_dims] = pad_shape/2;
Expand Down
Loading

0 comments on commit 14f2295

Please sign in to comment.