Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AutoParallel] Update matmul spmd rule name #57176

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion paddle/phi/api/yaml/legacy_ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -651,7 +651,7 @@
output : Tensor
infer_meta :
func : MatmulInferMeta
spmd_rule : MatmulSpmdInferForward
spmd_rule : MatmulInferSpmd
kernel :
func : matmul
backward : matmul_grad
Expand Down
18 changes: 9 additions & 9 deletions paddle/phi/infermeta/spmd_rules/matmul.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,10 @@ void FillMatmulOperandNotation(const int x_ndim,

////////////////// InferMeta(Contains SPMD) Functions //////////////////

SpmdInfo MatmulSpmdInferForward(const DistMetaTensor& x,
const DistMetaTensor& y,
bool trans_x,
bool trans_y) {
SpmdInfo MatmulInferSpmd(const DistMetaTensor& x,
const DistMetaTensor& y,
bool trans_x,
bool trans_y) {
// Step0: verify input args based on matmul logic
auto x_shape = phi::vectorize(x.dims());
auto y_shape = phi::vectorize(y.dims());
Expand Down Expand Up @@ -221,11 +221,11 @@ SpmdInfo MatmulSpmdInferForward(const DistMetaTensor& x,
return {{x_dist_attr_dst, y_dist_attr_dst}, {output_dist_attr_dst}};
}

SpmdInfo MatmulSpmdInferBackward(const DistMetaTensor& x,
const DistMetaTensor& y,
const DistMetaTensor& out,
bool trans_x,
bool trans_y) {
SpmdInfo MatmulInferSpmdReverse(const DistMetaTensor& x,
const DistMetaTensor& y,
const DistMetaTensor& out,
bool trans_x,
bool trans_y) {
auto out_shape = phi::vectorize(out.dims());
int out_ndim = out_shape.size();

Expand Down
14 changes: 7 additions & 7 deletions paddle/phi/infermeta/spmd_rules/matmul.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,16 @@ limitations under the License. */
namespace phi {
namespace distributed {

SpmdInfo MatmulSpmdInferForward(const DistMetaTensor& x,
SpmdInfo MatmulInferSpmd(const DistMetaTensor& x,
const DistMetaTensor& y,
bool trans_x,
bool trans_y);

SpmdInfo MatmulInferSpmdReverse(const DistMetaTensor& x,
const DistMetaTensor& y,
const DistMetaTensor& out,
bool trans_x,
bool trans_y);

SpmdInfo MatmulSpmdInferBackward(const DistMetaTensor& x,
const DistMetaTensor& y,
const DistMetaTensor& out,
bool trans_x,
bool trans_y);

} // namespace distributed
} // namespace phi
4 changes: 2 additions & 2 deletions paddle/phi/infermeta/spmd_rules/rules.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ namespace distributed {

// matmul rule
PD_REGISTER_SPMD_RULE(matmul,
PD_INFER_SPMD(phi::distributed::MatmulSpmdInferForward),
PD_INFER_SPMD(phi::distributed::MatmulSpmdInferBackward));
PD_INFER_SPMD(phi::distributed::MatmulInferSpmd),
PD_INFER_SPMD(phi::distributed::MatmulInferSpmdReverse));

} // namespace distributed
} // namespace phi