Skip to content

Commit

Permalink
Convinteger: weights zero point
Browse files Browse the repository at this point in the history
  • Loading branch information
kraiskil committed May 19, 2024
1 parent 44fc920 commit d8ee76f
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions src/nodes/convinteger.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,14 @@ class ConvInteger : public SpatialFilter {
const std::string &w_idx,
const std::string &y_idx) const override
{
std::string x_zero;
std::string x_zero="0";
if( get_number_of_inputs() >= 3 ) // x_zero_point is optional, 3rd input
x_zero = constant_acces_code( "x_zero_point[0]");
else
x_zero = "0";

std::string w_zero="0";
if( get_number_of_inputs() >= 4 ) // w_zero_point is optional, 4th input
w_zero = constant_acces_code( "w_zero_point[0]");


INDT_4 << get_W()->data_type_str() << " w_ = " << constant_acces_code("w[m][c][k0][k1]") << ";" << std::endl;
std::string dest;
Expand All @@ -47,7 +50,7 @@ class ConvInteger : public SpatialFilter {
else
dest = "y[b][m][o0][o1]";

INDT_4 << dest << "+= (x[b][c][i0+k0][i1+k1] - " << x_zero << ") * w_;" << std::endl;
INDT_4 << dest << "+= (x[b][c][i0+k0][i1+k1] - " << x_zero << ") * (w_ -" << w_zero << ");" << std::endl;
}

virtual void print_output_cell_finalize(std::ostream &dst, const std::string &y_idx) const override
Expand Down Expand Up @@ -78,7 +81,6 @@ class ConvInteger : public SpatialFilter {
name_input(2, "x_zero_point");
if( get_number_of_inputs() > 3 ){
name_input(3, "w_zero_point");
ERROR("unimplemented: weight zero points");
}

if( get_X()->data_dim.size() != 4 )
Expand Down

0 comments on commit d8ee76f

Please sign in to comment.