Skip to content

Commit

Permalink
Remove local tensor copies: matmul, matmulinteger
Browse files Browse the repository at this point in the history
  • Loading branch information
kraiskil committed Jul 26, 2023
1 parent 168a87e commit a167f31
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 54 deletions.
14 changes: 4 additions & 10 deletions src/nodes/matmul.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,12 @@ class MatMul : public Node {
public:
MatMul() {
op_name = "MatMul";
A=B=Y=NULL;
}
// inputs
const Tensor *A;
const Tensor *B;
// outputs
const Tensor *Y;


virtual void print(std::ostream &dst) const override
{
Tensor *A = inputs[0];
Tensor *B = inputs[1];
std::string type = A->data_type_str();

if( A->data_dim.size() != 2 )
Expand Down Expand Up @@ -44,8 +39,8 @@ class MatMul : public Node {
}
virtual void resolve(void) override
{
A = inputs[0];
B = inputs[1];
Tensor *A = inputs[0];
Tensor *B = inputs[1];
register_input(A, "A");
register_input(B, "B");
if( typeConstraint_highPrecisionNumeric(A) == false )
Expand All @@ -61,7 +56,6 @@ class MatMul : public Node {
rv->data_dim.push_back(rows);
rv->data_dim.push_back(cols);
rv->data_type = A->data_type;
Y=rv;
register_output(rv, "Y");
}

Expand Down
64 changes: 20 additions & 44 deletions src/nodes/matmulinteger.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,38 +15,13 @@ class MatMulInteger : public Node {
public:
MatMulInteger() {
op_name = "MatMulInteger";
A=B=Y=NULL;
a_zero_point=b_zero_point=NULL;
}
// inputs
const Tensor *A;
const Tensor *B;
// optional inputs
const Tensor *a_zero_point;
const Tensor *b_zero_point;
// outputs
const Tensor *Y;

virtual void print_parameters(std::ostream &dst, bool decorate ) const override
{
A->print_tensor_as_const(dst, !decorate);
dst << ", ";
B->print_tensor_as_const(dst, !decorate);
dst << ", ";
if( a_zero_point ) {
a_zero_point->print_tensor_as_const(dst, !decorate);
dst << ", ";
}
if( b_zero_point ) {
b_zero_point->print_tensor_as_const(dst, !decorate);
dst << ", ";
}
Y->print_tensor(dst, !decorate);
}


virtual void print(std::ostream &dst) const override
{
Tensor *A = inputs[0];
Tensor *B = inputs[1];
Tensor *Y = outputs[0];
std::string intype = A->data_type_str();
std::string outtype = Y->data_type_str();
std::string weighttype = B->data_type_str();
Expand All @@ -68,24 +43,27 @@ class MatMulInteger : public Node {
if( inner != inner2 )
ERROR("MatMulInteger input's inner dimensions don't match");

if( a_zero_point )
a_zero = a_zero_point->cname() + "[0]";
if( inputs.size() > 2)
a_zero = "a_zero_point[0]";
else
a_zero = "0";
if( b_zero_point )
b_zero = b_zero_point->cname() + "[0]";
if( inputs.size() > 3)
b_zero = "b_zero_point[0]";
else
b_zero = "0";

INDT_1 "/*MatMulInteger*/" << std::endl;
INDT_1 << intype << " *A = (" << intype << "*)" << A->cname() << ";" << std::endl;
INDT_1 << weighttype << " *B = (" << weighttype << "*)" << B->cname() << ";" << std::endl;
INDT_1 << outtype << " *Y = (" << outtype << "*)" << Y->cname() << ";" << std::endl;
INDT_1 << intype << " *A = (" << intype << "*)input_A;" << std::endl;
INDT_1 << weighttype << " *B = (" << weighttype << "*)input_B;" << std::endl;
INDT_1 << outtype << " *Y = (" << outtype << "*)output_Y;" << std::endl;

INDT_1 << "for( uint32_t r=0; r<" << rows << "; r++ )" << std::endl;
INDT_2 << "for( uint32_t c=0; c<" << cols << "; c++ ) {" << std::endl;


// NB: quantization here is the experimental ONNXC quantization
// that is not only integers, but also scales the output to 8 bits.
// This quantization is terribly kludgy, and really should be removed
if( options.quantize )
INDT_3 << "int32_t sum = 0;" << std::endl;
else
Expand All @@ -106,23 +84,22 @@ class MatMulInteger : public Node {
}

INDT_2 "}" << std::endl;

}

virtual void resolve(void) override
{
A = inputs[0];
B = inputs[1];
register_input(inputs[0], "input_A");
register_input(inputs[1], "input_B");

if( inputs.size() > 2 ) {
a_zero_point = inputs[2];
register_input(inputs[2], "a_zero_point");
/* There is no backend reference test for this case */
if( a_zero_point->data_dim[0] != 1 )
if( inputs[2]->data_dim[0] != 1 )
ERROR("Unimplemented: 1D zero_point input");
}
if( inputs.size() > 3 ) {
b_zero_point = inputs[3];
if( b_zero_point->data_dim[0] != 1 )
register_input(inputs[3], "b_zero_point");
if( inputs[3]->data_dim[0] != 1 )
ERROR("Unimplemented: 1D zero_point input");
}

Expand All @@ -137,8 +114,7 @@ class MatMulInteger : public Node {
rv->data_type = onnx::TensorProto_DataType_INT8;
else
rv->data_type = onnx::TensorProto_DataType_INT32;
Y=rv;
outputs.push_back(rv);
register_output(rv, "output_Y");
}

void result_dim( const std::vector< Tensor*> &inputs, int32_t &rows, int32_t &cols) const
Expand Down

0 comments on commit a167f31

Please sign in to comment.