Skip to content

Commit

Permalink
squash the wrong commit
Browse files Browse the repository at this point in the history
rollback

add the type

Update fusion.cc
  • Loading branch information
zeroRains committed Nov 8, 2023
1 parent d1870a1 commit f9653a6
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 461 deletions.
2 changes: 2 additions & 0 deletions paddle/phi/api/yaml/fused_ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@
func : FusionRepeatedFCReluInferMeta
kernel :
func : fusion_repeated_fc_relu
data_type: x
intermediate : relu_out

- op : fusion_seqconv_eltadd_relu
Expand Down Expand Up @@ -264,6 +265,7 @@
func : FusionSquaredMatSubInferMeta
kernel :
func : fusion_squared_mat_sub
data_type : x
intermediate : squared_x, squared_y, squared_xy

- op : fusion_transpose_flatten_concat
Expand Down
5 changes: 5 additions & 0 deletions paddle/phi/infermeta/fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2261,6 +2261,7 @@ void FusionRepeatedFCReluInferMeta(const MetaTensor& x,
}
out->set_dims({i_dims[0], w_dims[sz - 1][1]});
out->share_lod(x);
out->set_dtype(x.dtype());
}

void FusionSquaredMatSubInferMeta(const MetaTensor& x,
Expand Down Expand Up @@ -2295,9 +2296,13 @@ void FusionSquaredMatSubInferMeta(const MetaTensor& x,
x_dims[1],
y_dims[0]));
squared_x->set_dims(x_dims);
squared_x->set_dtype(x.dtype());
squared_y->set_dims(y_dims);
squared_y->set_dtype(x.dtype());
squared_xy->set_dims({x_dims[0], y_dims[1]});
squared_xy->set_dtype(x.dtype());
out->set_dims({x_dims[0], y_dims[1]});
out->set_dtype(x.dtype());
}

void FusionGRUInferMeta(const MetaTensor& x,
Expand Down
Loading

0 comments on commit f9653a6

Please sign in to comment.