diff --git a/src/nodes/matmulinteger.h b/src/nodes/matmulinteger.h index 4617207..92f237a 100644 --- a/src/nodes/matmulinteger.h +++ b/src/nodes/matmulinteger.h @@ -49,6 +49,7 @@ class MatMulInteger : public Node { { std::string intype = A->data_type_str(); std::string outtype = Y->data_type_str(); + std::string weighttype = B->data_type_str(); std::string a_zero; std::string b_zero; @@ -78,7 +79,7 @@ class MatMulInteger : public Node { INDT_1 "/*MatMulInteger*/" << std::endl; INDT_1 << intype << " *A = (" << intype << "*)" << A->cname() << ";" << std::endl; - INDT_1 << intype << " *B = (" << intype << "*)" << B->cname() << ";" << std::endl; + INDT_1 << weighttype << " *B = (" << weighttype << "*)" << B->cname() << ";" << std::endl; INDT_1 << outtype << " *Y = (" << outtype << "*)" << Y->cname() << ";" << std::endl; INDT_1 << "for( uint32_t r=0; r<" << rows << "; r++ )" << std::endl;