diff --git a/src/nodes/matmul.h b/src/nodes/matmul.h index c05c30d..2ea18d3 100644 --- a/src/nodes/matmul.h +++ b/src/nodes/matmul.h @@ -5,17 +5,12 @@ class MatMul : public Node { public: MatMul() { op_name = "MatMul"; - A=B=Y=NULL; } - // inputs - const Tensor *A; - const Tensor *B; - // outputs - const Tensor *Y; - virtual void print(std::ostream &dst) const override { + Tensor *A = inputs[0]; + Tensor *B = inputs[1]; std::string type = A->data_type_str(); if( A->data_dim.size() != 2 ) @@ -44,8 +39,8 @@ class MatMul : public Node { } virtual void resolve(void) override { - A = inputs[0]; - B = inputs[1]; + Tensor *A = inputs[0]; + Tensor *B = inputs[1]; register_input(A, "A"); register_input(B, "B"); if( typeConstraint_highPrecisionNumeric(A) == false ) @@ -61,7 +56,6 @@ class MatMul : public Node { rv->data_dim.push_back(rows); rv->data_dim.push_back(cols); rv->data_type = A->data_type; - Y=rv; register_output(rv, "Y"); } diff --git a/src/nodes/matmulinteger.h b/src/nodes/matmulinteger.h index 6e5a7a3..eb8cf60 100644 --- a/src/nodes/matmulinteger.h +++ b/src/nodes/matmulinteger.h @@ -15,38 +15,13 @@ class MatMulInteger : public Node { public: MatMulInteger() { op_name = "MatMulInteger"; - A=B=Y=NULL; - a_zero_point=b_zero_point=NULL; } - // inputs - const Tensor *A; - const Tensor *B; - // optional inputs - const Tensor *a_zero_point; - const Tensor *b_zero_point; - // outputs - const Tensor *Y; - - virtual void print_parameters(std::ostream &dst, bool decorate ) const override - { - A->print_tensor_as_const(dst, !decorate); - dst << ", "; - B->print_tensor_as_const(dst, !decorate); - dst << ", "; - if( a_zero_point ) { - a_zero_point->print_tensor_as_const(dst, !decorate); - dst << ", "; - } - if( b_zero_point ) { - b_zero_point->print_tensor_as_const(dst, !decorate); - dst << ", "; - } - Y->print_tensor(dst, !decorate); - } - virtual void print(std::ostream &dst) const override { + Tensor *A = inputs[0]; + Tensor *B = inputs[1]; + Tensor *Y = outputs[0]; std::string intype = A->data_type_str(); std::string outtype = Y->data_type_str(); std::string weighttype = B->data_type_str(); @@ -68,24 +43,27 @@ class MatMulInteger : public Node { if( inner != inner2 ) ERROR("MatMulInteger input's inner dimensions don't match"); - if( a_zero_point ) - a_zero = a_zero_point->cname() + "[0]"; + if( inputs.size() > 2) + a_zero = "a_zero_point[0]"; else a_zero = "0"; - if( b_zero_point ) - b_zero = b_zero_point->cname() + "[0]"; + if( inputs.size() > 3) + b_zero = "b_zero_point[0]"; else b_zero = "0"; INDT_1 "/*MatMulInteger*/" << std::endl; - INDT_1 << intype << " *A = (" << intype << "*)" << A->cname() << ";" << std::endl; - INDT_1 << weighttype << " *B = (" << weighttype << "*)" << B->cname() << ";" << std::endl; - INDT_1 << outtype << " *Y = (" << outtype << "*)" << Y->cname() << ";" << std::endl; + INDT_1 << intype << " *A = (" << intype << "*)input_A;" << std::endl; + INDT_1 << weighttype << " *B = (" << weighttype << "*)input_B;" << std::endl; + INDT_1 << outtype << " *Y = (" << outtype << "*)output_Y;" << std::endl; INDT_1 << "for( uint32_t r=0; r<" << rows << "; r++ )" << std::endl; INDT_2 << "for( uint32_t c=0; c<" << cols << "; c++ ) {" << std::endl; + // NB: quantization here is the experimental ONNXC quantization + // that is not only integers, but also scales the output to 8bits. + // This quantization terribly kludgy, and really should be removed if( options.quantize ) INDT_3 << "int32_t sum = 0;" << std::endl; else @@ -106,23 +84,22 @@ class MatMulInteger : public Node { } INDT_2 "}" << std::endl; - } virtual void resolve(void) override { - A = inputs[0]; - B = inputs[1]; + register_input(inputs[0], "input_A"); + register_input(inputs[1], "input_B"); if( inputs.size() > 2 ) { - a_zero_point = inputs[2]; + register_input(inputs[2], "a_zero_point"); /* There is no backend reference test for this case */ - if( a_zero_point->data_dim[0] != 1 ) + if( inputs[2]->data_dim[0] != 1 ) ERROR("Unimplemented: 1D zero_point input"); } if( inputs.size() > 3 ) { - b_zero_point = inputs[3]; - if( b_zero_point->data_dim[0] != 1 ) + register_input(inputs[3], "b_zero_point"); + if( inputs[3]->data_dim[0] != 1 ) ERROR("Unimplemented: 1D zero_point input"); } @@ -137,8 +114,7 @@ class MatMulInteger : public Node { rv->data_type = onnx::TensorProto_DataType_INT8; else rv->data_type = onnx::TensorProto_DataType_INT32; - Y=rv; - outputs.push_back(rv); + register_output(rv, "output_Y"); } void result_dim( const std::vector< Tensor*> &inputs, int32_t &rows, int32_t &cols) const