add elementwise backward rule (PaddlePaddle#56506)
pkuzyc authored and BeingGod committed Sep 9, 2023
1 parent a5cce95 commit 07cd41e
Showing 3 changed files with 333 additions and 10 deletions.
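
In short, the new reverse rule runs the forward rule's einsum-notation trick in the opposite direction: it labels the output's axes with letters, marks each input's broadcast (size-1) axes with the special character '1', and then projects the output's dims mapping back onto every input. A minimal, self-contained sketch of the notation step follows; GetBroadcastAxes is re-implemented locally for illustration and is only assumed to behave like Paddle's helper of the same name (returning the trailing ndim letters of a max_ndim-letter output notation):

#include <cassert>
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Local stand-in for Paddle's GetBroadcastAxes (an assumption, not the real
// implementation): the rank-max_ndim output is labeled with the first
// max_ndim letters, and a rank-ndim input gets the trailing ndim of them.
std::string GetBroadcastAxes(int64_t ndim, int64_t max_ndim,
                             const std::string& alphabet) {
  assert(ndim <= max_ndim);
  return alphabet.substr(max_ndim - ndim, ndim);
}

int main() {
  const std::string alphabet = "abcdefghijklmnopqrstuvwxyz";
  const int64_t output_ndim = 3;  // output notation: "abc"

  // A rank-2 input of shape [1, 8] aligned to the last two output axes.
  std::vector<int64_t> shape = {1, 8};
  int64_t ndim = static_cast<int64_t>(shape.size());
  int64_t start_dim = output_ndim - ndim;
  std::string axes = GetBroadcastAxes(ndim, output_ndim, alphabet);  // "bc"

  // Same marking loop as InferBackward below: a size-1 axis becomes '1',
  // i.e. "broadcast here, no sharding may propagate to this input axis".
  for (int64_t idim = 0; idim < output_ndim; ++idim) {
    if (idim >= start_dim && shape[idim - start_dim] == 1) {
      axes[idim - start_dim] = '1';
    }
  }
  std::cout << axes << std::endl;  // prints "1c"
}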
@@ -25,7 +25,7 @@ ElementwiseSPMDRule::InferForward(
const std::vector<DistTensorSpec>& input_specs,
const paddle::framework::AttributeMap& attrs) {
// step0: Verify Input Args Based on Elementwise Logic
-  int64_t ninputs = static_cast<int64_t>(input_specs.size());
+  int64_t ninputs = input_specs.size();
PADDLE_ENFORCE_GT(
ninputs,
0,
@@ -39,7 +39,7 @@ ElementwiseSPMDRule::InferForward(
std::vector<std::string> input_axes_vec;
int64_t max_ndim = 0;
for (int64_t i = 0; i < ninputs; ++i) {
-    int64_t ndim = static_cast<int64_t>(input_specs[i].shape().size());
+    int64_t ndim = input_specs[i].shape().size();
if (ndim > max_ndim) {
max_ndim = ndim;
}
@@ -49,7 +49,7 @@ ElementwiseSPMDRule::InferForward(
std::vector<int64_t> broadcast_axis_count(max_ndim, 0);
for (int64_t i = 0; i < ninputs; ++i) {
std::vector<int64_t> shape = input_specs[i].shape();
-    int64_t ndim = static_cast<int64_t>(shape.size());
+    int64_t ndim = shape.size();
int64_t start_dim = max_ndim - ndim;
std::string axes_notation = GetBroadcastAxes(ndim, max_ndim, alphabet);
if (ninputs > 1) {
@@ -108,8 +108,8 @@ ElementwiseSPMDRule::InferForward(
new_input_dist_attrs.emplace_back(dist_attr);
}

-  // step2.4: handle partial
-  // Step2.3.2 handle input tensor partial (TODO)
+  // step3: handle partial
+  // handle input tensor partial (TODO)
VLOG(4) << "ElementwiseSPMDRule InferForward:";
for (int64_t i = 0; i < ninputs; i++) {
VLOG(4) << "Input" << std::to_string(i) << " shape: ["
@@ -127,12 +127,85 @@ ElementwiseSPMDRule::InferForward(

std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
ElementwiseSPMDRule::InferBackward(
+    const std::vector<DistTensorSpec>& input_specs,
const std::vector<DistTensorSpec>& output_specs,
const paddle::framework::AttributeMap& attrs) {
-  PADDLE_THROW(phi::errors::Unimplemented(
-      "InferBackward of ElementwiseSPMDRule is NOT implemented yet."));
+  // step0: Verify Input Args Based on Elementwise Logic
+  int64_t ninputs = input_specs.size();
+  int64_t noutputs = output_specs.size();
+  PADDLE_ENFORCE_GT(
+      ninputs,
+      0,
+      phi::errors::InvalidArgument("The size of InputSpec in elementwise must "
+                                   "be greater than 0, but got [%d].",
+                                   ninputs));
+  PADDLE_ENFORCE_EQ(
+      noutputs,
+      1,
+      phi::errors::InvalidArgument("The size of OutputSpec in elementwise must "
+                                   "be equal to 1, but got [%d].",
+                                   noutputs));
+  VerifySpecs(output_specs, "elementwise_backward");
+
+  // step1: Build Einsum Notation
+  std::string alphabet = "abcdefghijklmnopqrstuvwxyz";
+  std::vector<std::string> input_axes_vec;
+  int64_t output_ndim = output_specs[0].shape().size();
+  std::string output_axes =
+      GetBroadcastAxes(output_ndim, output_ndim, alphabet);
+
+  // get einsum notation for each input, deal with broadcast
+  for (int64_t i = 0; i < ninputs; ++i) {
+    const std::vector<int64_t>& shape = input_specs[i].shape();
+    int64_t ndim = shape.size();
+    int64_t start_dim = output_ndim - ndim;
+    std::string axes_notation = GetBroadcastAxes(ndim, output_ndim, alphabet);
+    if (ninputs > 1) {
+      for (int64_t idim = 0; idim < output_ndim; idim++) {
+        // deal with the broadcast axes
+        if (idim >= start_dim && shape[idim - start_dim] == 1) {
+          // mark the broadcast axis with a special "1"
+          axes_notation[idim - start_dim] = '1';
+        }
+      }
+    }
+    input_axes_vec.emplace_back(axes_notation);
+  }
+
+  // step2: Sharding Propagation
+  // step2.1: get dim mapping for each output axis
+  std::unordered_map<std::string, int64_t> axis_to_dim_map =
+      ShardingMergeForTensors({{output_axes, output_specs[0].dims_mapping()}});
+
+  // step2.2: infer input dims mappings from output dims mapping
+  // and get the input distributed attributes to return
+  std::vector<TensorDistAttr> input_dist_attrs;
+  std::vector<TensorDistAttr> output_dist_attrs;
+  for (int64_t i = 0; i < ninputs; ++i) {
+    const DistTensorSpec& spec = input_specs[i];
+    TensorDistAttr dist_attr(spec.dist_attr());
+    std::vector<int64_t> dims_mapping =
+        GetDimsMappingForAxes(input_axes_vec[i], axis_to_dim_map);
+    dist_attr.set_dims_mapping(dims_mapping);
+    input_dist_attrs.emplace_back(dist_attr);
+  }
+
+  output_dist_attrs.emplace_back(output_specs[0].dist_attr());
+
+  // step3: handle partial (TODO)
+
+  VLOG(4) << "ElementwiseSPMDRule InferBackward:";
+  VLOG(4) << "Output shape: [" << str_join(output_specs[0].shape())
+          << "] dims_mapping: [" << str_join(output_specs[0].dims_mapping())
+          << "]";
+  for (int64_t i = 0; i < ninputs; i++) {
+    VLOG(4) << "Input" << std::to_string(i) << " shape: ["
+            << str_join(input_specs[i].shape()) << "] "
+            << "dims_mapping: [" << str_join(input_dist_attrs[i].dims_mapping())
+            << "]";
+  }
+
-  return {};
+  return {input_dist_attrs, output_dist_attrs};
}

} // namespace auto_parallel
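
To make the sharding propagation in step 2 concrete: ShardingMergeForTensors turns the output's dims mapping into a per-axis map (e.g. output axes "abc" with dims mapping [0, -1, 1] give {a: 0, b: -1, c: 1}), and GetDimsMappingForAxes reads each input's notation back through that map. The stand-in below assumes, consistent with the broadcast marking above, that the '1' marker resolves to -1 (replicated); that detail is an assumption about Paddle's helper, not code from this commit:

#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Illustrative stand-in for GetDimsMappingForAxes: look each axis letter up
// in the merged map; the broadcast marker '1' is assumed to stay
// replicated (mesh dim -1).
std::vector<int64_t> DimsMappingForAxes(
    const std::string& axes,
    const std::unordered_map<std::string, int64_t>& axis_to_dim_map) {
  std::vector<int64_t> dims_mapping;
  for (char c : axes) {
    if (c == '1') {
      dims_mapping.push_back(-1);
    } else {
      dims_mapping.push_back(axis_to_dim_map.at(std::string(1, c)));
    }
  }
  return dims_mapping;
}

int main() {
  // Output "abc" sharded as [0, -1, 1]: axis a on mesh dim 0, c on mesh dim 1.
  std::unordered_map<std::string, int64_t> axis_to_dim_map = {
      {"a", 0}, {"b", -1}, {"c", 1}};
  for (const std::string& axes : {std::string("abc"), std::string("1c")}) {
    std::cout << axes << " ->";
    for (int64_t d : DimsMappingForAxes(axes, axis_to_dim_map)) {
      std::cout << " " << d;
    }
    std::cout << "\n";  // "abc -> 0 -1 1" and "1c -> -1 1"
  }
}

Under that assumption, a broadcast input axis can never end up sharded, which is the safe choice for an elementwise op: a size-1 axis holds the same value on every mesh slice.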
@@ -32,7 +32,8 @@ class ElementwiseSPMDRule : public SPMDRuleBase {
const paddle::framework::AttributeMap& attrs) override;

std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
-  InferBackward(const std::vector<DistTensorSpec>& output_specs,
+  InferBackward(const std::vector<DistTensorSpec>& input_specs,
+                const std::vector<DistTensorSpec>& output_specs,
const paddle::framework::AttributeMap& attrs) override;
};
} // namespace auto_parallel
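
With the widened signature, a call site would pass both sides, hypothetically along these lines (x_spec, y_spec, and out_spec stand for DistTensorSpec values constructed elsewhere, e.g. in a unit test; their construction is elided, so this fragment is illustrative rather than compilable on its own):

// Hypothetical usage sketch, not taken from the commit.
ElementwiseSPMDRule rule;
paddle::framework::AttributeMap attrs;  // accepted but unused by InferBackward above
std::vector<DistTensorSpec> input_specs = {x_spec, y_spec};
std::vector<DistTensorSpec> output_specs = {out_spec};

// first:  one inferred TensorDistAttr per input
// second: the output's TensorDistAttr, passed through unchanged
auto inferred = rule.InferBackward(input_specs, output_specs, attrs);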
