Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Semi-Auto] Add elementwise infer_backward rule #56506

Merged
merged 1 commit into from
Sep 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ ElementwiseSPMDRule::InferForward(
const std::vector<DistTensorSpec>& input_specs,
const paddle::framework::AttributeMap& attrs) {
// step0: Verify Input Args Based on Elementwise Logic
int64_t ninputs = static_cast<int64_t>(input_specs.size());
int64_t ninputs = input_specs.size();
PADDLE_ENFORCE_GT(
ninputs,
0,
Expand All @@ -39,7 +39,7 @@ ElementwiseSPMDRule::InferForward(
std::vector<std::string> input_axes_vec;
int64_t max_ndim = 0;
for (int64_t i = 0; i < ninputs; ++i) {
int64_t ndim = static_cast<int64_t>(input_specs[i].shape().size());
int64_t ndim = input_specs[i].shape().size();
if (ndim > max_ndim) {
max_ndim = ndim;
}
Expand All @@ -49,7 +49,7 @@ ElementwiseSPMDRule::InferForward(
std::vector<int64_t> broadcast_axis_count(max_ndim, 0);
for (int64_t i = 0; i < ninputs; ++i) {
std::vector<int64_t> shape = input_specs[i].shape();
int64_t ndim = static_cast<int64_t>(shape.size());
int64_t ndim = shape.size();
int64_t start_dim = max_ndim - ndim;
std::string axes_notation = GetBroadcastAxes(ndim, max_ndim, alphabet);
if (ninputs > 1) {
Expand Down Expand Up @@ -108,8 +108,8 @@ ElementwiseSPMDRule::InferForward(
new_input_dist_attrs.emplace_back(dist_attr);
}

// step2.4: handle partial
// Step2.3.2 handle input tensor partial (TODO)
// step3: handle partial
// handle input tensor partial (TODO)
VLOG(4) << "ElementwiseSPMDRule InferForward:";
for (int64_t i = 0; i < ninputs; i++) {
VLOG(4) << "Input" << std::to_string(i) << " shape: ["
Expand All @@ -127,12 +127,85 @@ ElementwiseSPMDRule::InferForward(

// Reverse sharding propagation for elementwise ops: given the output's
// distributed attributes, infer a consistent dims mapping for every input
// (handling numpy-style broadcast dims), and echo the output's dist attr back.
//
// @param input_specs  shapes + dist attrs of the op inputs (size >= 1).
// @param output_specs shapes + dist attrs of the op outputs (size must be 1).
// @param attrs        op attribute map (unused by the elementwise rule).
// @return pair of {input dist attrs, output dist attrs}.
//
// NOTE(review): the rendered diff fused the old stub (PADDLE_THROW /
// `return {};`) with the new implementation; this is the resolved new body.
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
ElementwiseSPMDRule::InferBackward(
    const std::vector<DistTensorSpec>& input_specs,
    const std::vector<DistTensorSpec>& output_specs,
    const paddle::framework::AttributeMap& attrs) {
  // step0: Verify Input Args Based on Elementwise Logic
  int64_t ninputs = input_specs.size();
  int64_t noutputs = output_specs.size();
  PADDLE_ENFORCE_GT(
      ninputs,
      0,
      phi::errors::InvalidArgument("The size of InputSpec in elementwise must "
                                   "be greater than 0, but got [%d].",
                                   ninputs));
  // Elementwise ops produce exactly one output tensor.
  PADDLE_ENFORCE_EQ(
      noutputs,
      1,
      phi::errors::InvalidArgument("The size of OutputSpec in elementwise must "
                                   "be equal to 1, but got [%d].",
                                   noutputs));
  VerifySpecs(output_specs, "elementwise_backward");

  // step1: Build Einsum Notation
  // The output defines the full axis set; each input maps onto a trailing
  // suffix of those axes (broadcast aligns shapes from the right).
  std::string alphabet = "abcdefghijklmnopqrstuvwxyz";
  std::vector<std::string> input_axes_vec;
  int64_t output_ndim = output_specs[0].shape().size();
  std::string output_axes =
      GetBroadcastAxes(output_ndim, output_ndim, alphabet);

  // get einsum notation for each input, deal with broadcast
  for (int64_t i = 0; i < ninputs; ++i) {
    const std::vector<int64_t>& shape = input_specs[i].shape();
    int64_t ndim = shape.size();
    int64_t start_dim = output_ndim - ndim;
    std::string axes_notation = GetBroadcastAxes(ndim, output_ndim, alphabet);
    if (ninputs > 1) {
      for (int64_t idim = 0; idim < output_ndim; idim++) {
        // deal with the broadcast axes: a size-1 input dim aligned with an
        // output dim must not inherit that dim's sharding
        if (idim >= start_dim && shape[idim - start_dim] == 1) {
          // mark the broadcast axis to a special "1"
          axes_notation[idim - start_dim] = '1';
        }
      }
    }
    input_axes_vec.emplace_back(axes_notation);
  }

  // step2: Sharding Propagation
  // step2.1: get dim mapping for each output axis
  std::unordered_map<std::string, int64_t> axis_to_dim_map =
      ShardingMergeForTensors({{output_axes, output_specs[0].dims_mapping()}});

  // step2.2: infer input dims mappings from output dims mapping
  // and get the input distributed attributes to return
  std::vector<TensorDistAttr> input_dist_attrs;
  std::vector<TensorDistAttr> output_dist_attrs;
  for (int64_t i = 0; i < ninputs; ++i) {
    const DistTensorSpec& spec = input_specs[i];
    // start from the input's existing dist attr so non-mapping fields
    // (e.g. process mesh) are preserved
    TensorDistAttr dist_attr(spec.dist_attr());
    std::vector<int64_t> dims_mapping =
        GetDimsMappingForAxes(input_axes_vec[i], axis_to_dim_map);
    dist_attr.set_dims_mapping(dims_mapping);
    input_dist_attrs.emplace_back(dist_attr);
  }

  // The output's dist attr is taken as-is; backward inference does not
  // reshard the output.
  output_dist_attrs.emplace_back(output_specs[0].dist_attr());

  // step3: handle partial (TODO)

  VLOG(4) << "ElementwiseSPMDRule InferBackward:";
  VLOG(4) << "Output shape: [" << str_join(output_specs[0].shape())
          << "] dims_mapping: [" << str_join(output_specs[0].dims_mapping())
          << "]";
  for (int64_t i = 0; i < ninputs; i++) {
    VLOG(4) << "Input" << std::to_string(i) << " shape: ["
            << str_join(input_specs[i].shape()) << "] "
            << "dims_mapping: [" << str_join(input_dist_attrs[i].dims_mapping())
            << "]";
  }

  return {input_dist_attrs, output_dist_attrs};
}

} // namespace auto_parallel
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ class ElementwiseSPMDRule : public SPMDRuleBase {
const paddle::framework::AttributeMap& attrs) override;

std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
InferBackward(const std::vector<DistTensorSpec>& output_specs,
InferBackward(const std::vector<DistTensorSpec>& input_specs,
const std::vector<DistTensorSpec>& output_specs,
const paddle::framework::AttributeMap& attrs) override;
};
} // namespace auto_parallel
Expand Down
Loading