diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/elementwise_spmd_rule.cc b/paddle/fluid/distributed/auto_parallel/spmd_rules/elementwise_spmd_rule.cc index 2ab6636840e0ec..7904627cf7fb7d 100644 --- a/paddle/fluid/distributed/auto_parallel/spmd_rules/elementwise_spmd_rule.cc +++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/elementwise_spmd_rule.cc @@ -25,7 +25,7 @@ ElementwiseSPMDRule::InferForward( const std::vector& input_specs, const paddle::framework::AttributeMap& attrs) { // step0: Verify Input Args Based on Elementwise Logic - int64_t ninputs = static_cast(input_specs.size()); + int64_t ninputs = input_specs.size(); PADDLE_ENFORCE_GT( ninputs, 0, @@ -39,7 +39,7 @@ ElementwiseSPMDRule::InferForward( std::vector input_axes_vec; int64_t max_ndim = 0; for (int64_t i = 0; i < ninputs; ++i) { - int64_t ndim = static_cast(input_specs[i].shape().size()); + int64_t ndim = input_specs[i].shape().size(); if (ndim > max_ndim) { max_ndim = ndim; } @@ -49,7 +49,7 @@ ElementwiseSPMDRule::InferForward( std::vector broadcast_axis_count(max_ndim, 0); for (int64_t i = 0; i < ninputs; ++i) { std::vector shape = input_specs[i].shape(); - int64_t ndim = static_cast(shape.size()); + int64_t ndim = shape.size(); int64_t start_dim = max_ndim - ndim; std::string axes_notation = GetBroadcastAxes(ndim, max_ndim, alphabet); if (ninputs > 1) { @@ -108,8 +108,8 @@ ElementwiseSPMDRule::InferForward( new_input_dist_attrs.emplace_back(dist_attr); } - // step2.4: handle partial - // Step2.3.2 handle input tensor partial (TODO) + // step3: handle partial + // handle input tensor partial (TODO) VLOG(4) << "ElementwiseSPMDRule InferForward:"; for (int64_t i = 0; i < ninputs; i++) { VLOG(4) << "Input" << std::to_string(i) << " shape: [" @@ -127,12 +127,85 @@ ElementwiseSPMDRule::InferForward( std::pair, std::vector> ElementwiseSPMDRule::InferBackward( + const std::vector& input_specs, const std::vector& output_specs, const paddle::framework::AttributeMap& attrs) { - PADDLE_THROW(phi::errors::Unimplemented( - "InferBackward of ElementwiseSPMDRule is NOT implemented yet.")); + // step0: Verify Input Args Based on Elementwise Logic + int64_t ninputs = input_specs.size(); + int64_t noutputs = output_specs.size(); + PADDLE_ENFORCE_GT( + ninputs, + 0, + phi::errors::InvalidArgument("The size of InputSpec in elementwise must " + "be greater than 0, but got [%d].", + ninputs)); + PADDLE_ENFORCE_EQ( + noutputs, + 1, + phi::errors::InvalidArgument("The size of OutputSpec in elementwise must " + "be equal to 1, but got [%d].", + noutputs)); + VerifySpecs(output_specs, "elementwise_backward"); + + // step1: Build Einsum Notation + std::string alphabet = "abcdefghijklmnopqrstuvwxyz"; + std::vector input_axes_vec; + int64_t output_ndim = output_specs[0].shape().size(); + std::string output_axes = + GetBroadcastAxes(output_ndim, output_ndim, alphabet); + + // get einsum notation for each input, deal with broadcast + for (int64_t i = 0; i < ninputs; ++i) { + const std::vector& shape = input_specs[i].shape(); + int64_t ndim = shape.size(); + int64_t start_dim = output_ndim - ndim; + std::string axes_notation = GetBroadcastAxes(ndim, output_ndim, alphabet); + if (ninputs > 1) { + for (int64_t idim = 0; idim < output_ndim; idim++) { + // deal with the broadcast axes + if (idim >= start_dim && shape[idim - start_dim] == 1) { + // mark the broadcast axis to a special "1" + axes_notation[idim - start_dim] = '1'; + } + } + } + input_axes_vec.emplace_back(axes_notation); + } + + // step2: Sharding Propogation + // step2.1: get dim mapping for each output axis + std::unordered_map axis_to_dim_map = + ShardingMergeForTensors({{output_axes, output_specs[0].dims_mapping()}}); + + // step2.2: infer input dims mappings from output dims mapping + // and get the input distributed attributes to return + std::vector input_dist_attrs; + std::vector output_dist_attrs; + for (int64_t i = 0; i < ninputs; ++i) { + const DistTensorSpec& spec = input_specs[i]; + TensorDistAttr dist_attr(spec.dist_attr()); + std::vector dims_mapping = + GetDimsMappingForAxes(input_axes_vec[i], axis_to_dim_map); + dist_attr.set_dims_mapping(dims_mapping); + input_dist_attrs.emplace_back(dist_attr); + } + + output_dist_attrs.emplace_back(output_specs[0].dist_attr()); + + // step3: handle partial (TODO) + + VLOG(4) << "ElementwiseSPMDRule InferBackward:"; + VLOG(4) << "Output shape: [" << str_join(output_specs[0].shape()) + << "] dims_mapping: [" << str_join(output_specs[0].dims_mapping()) + << "]"; + for (int64_t i = 0; i < ninputs; i++) { + VLOG(4) << "Input" << std::to_string(i) << " shape: [" + << str_join(input_specs[i].shape()) << "] " + << "dims_mapping: [" << str_join(input_dist_attrs[i].dims_mapping()) + << "]"; + } - return {}; + return {input_dist_attrs, output_dist_attrs}; } } // namespace auto_parallel diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/elementwise_spmd_rule.h b/paddle/fluid/distributed/auto_parallel/spmd_rules/elementwise_spmd_rule.h index 113c34e4f43ab9..ed01d23252b21d 100644 --- a/paddle/fluid/distributed/auto_parallel/spmd_rules/elementwise_spmd_rule.h +++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/elementwise_spmd_rule.h @@ -32,7 +32,8 @@ class ElementwiseSPMDRule : public SPMDRuleBase { const paddle::framework::AttributeMap& attrs) override; std::pair, std::vector> - InferBackward(const std::vector& output_specs, + InferBackward(const std::vector& input_specs, + const std::vector& output_specs, const paddle::framework::AttributeMap& attrs) override; }; } // namespace auto_parallel diff --git a/test/auto_parallel/spmd_rules/test_elementwise_rule.py b/test/auto_parallel/spmd_rules/test_elementwise_rule.py index 34e3194410cc18..59a121c4bf0b3f 100644 --- a/test/auto_parallel/spmd_rules/test_elementwise_rule.py +++ b/test/auto_parallel/spmd_rules/test_elementwise_rule.py @@ -40,6 +40,8 @@ def setUp(self): y_tensor_dist_attr.process_mesh = process_mesh self.y_dist_tensor_spec = DistTensorSpec(y_shape, y_tensor_dist_attr) + self.out_dist_tensor_spec = DistTensorSpec(self.x_dist_tensor_spec) + self.attrs = {} def test_single_mesh_dim(self): @@ -87,7 +89,7 @@ def test_single_mesh_dim(self): self.x_dist_tensor_spec.set_dims_mapping([-1, 0]) result_dist_attrs = self.rule.infer_forward( - [self.x_dist_tensor_spec, self.y_dist_tensor_spec], self.attrs + [self.x_dist_tensor_spec], self.attrs ) infered_input_dist_attrs = result_dist_attrs[0] infered_output_dist_attrs = result_dist_attrs[1] @@ -309,6 +311,253 @@ def test_multi_mesh_dim_broadcast(self): infered_output_dist_attrs[0].dims_mapping, [-1, 0, -1, 1] ) + def test_backward_single_mesh_dim(self): + # [0, -1] --> [0, -1], [0, -1], [0, -1] (output --> inputs, output) + self.out_dist_tensor_spec.set_dims_mapping([0, -1]) + result_dist_attrs = self.rule.infer_backward( + [self.x_dist_tensor_spec, self.y_dist_tensor_spec], + [self.out_dist_tensor_spec], + self.attrs, + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1]) + self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [0, -1]) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, -1]) + + # [-1, -1] --> [-1, -1], [-1, -1], [-1, -1] (output --> inputs, output) + self.out_dist_tensor_spec.set_dims_mapping([-1, -1]) + + result_dist_attrs = self.rule.infer_backward( + [self.x_dist_tensor_spec, self.y_dist_tensor_spec], + [self.out_dist_tensor_spec], + self.attrs, + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, -1]) + self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [-1, -1]) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, -1]) + + # [-1, 0]--> [-1, 0], [-1, 0] (output --> inputs, output) + self.out_dist_tensor_spec.set_dims_mapping([-1, 0]) + + result_dist_attrs = self.rule.infer_backward( + [self.x_dist_tensor_spec], [self.out_dist_tensor_spec], self.attrs + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, 0]) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, 0]) + + def test_backward_single_mesh_dim_broadcast(self): + self.x_dist_tensor_spec.shape = [64, 36, 12] + self.y_dist_tensor_spec.shape = [12] + self.out_dist_tensor_spec.shape = [64, 36, 12] + + # [0, -1, -1] --> [0, -1, -1], [-1], [0, -1, -1] (output --> inputs, output) + self.out_dist_tensor_spec.set_dims_mapping([0, -1, -1]) + + resulted_dist_attrs = self.rule.infer_backward( + [self.x_dist_tensor_spec, self.y_dist_tensor_spec], + [self.out_dist_tensor_spec], + self.attrs, + ) + infered_input_dist_attrs = resulted_dist_attrs[0] + infered_output_dist_attrs = resulted_dist_attrs[1] + + self.assertEqual(len(resulted_dist_attrs), 2) + self.assertEqual(len(infered_input_dist_attrs), 2) + self.assertEqual(len(infered_output_dist_attrs), 1) + + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1, -1]) + self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [-1]) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, -1, -1]) + + # [-1, 0, -1] --> [-1, 0, -1], [-1], [-1, 0, -1] (output --> inputs, output) + self.out_dist_tensor_spec.set_dims_mapping([-1, 0, -1]) + + resulted_dist_attrs = self.rule.infer_backward( + [self.x_dist_tensor_spec, self.y_dist_tensor_spec], + [self.out_dist_tensor_spec], + self.attrs, + ) + infered_input_dist_attrs = resulted_dist_attrs[0] + infered_output_dist_attrs = resulted_dist_attrs[1] + + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, 0, -1]) + self.assertEqual((infered_input_dist_attrs[1].dims_mapping), [-1]) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, 0, -1]) + + # [-1, -1, 0] --> [-1, -1, 0], [0], [-1, -1, 0] (output --> inputs, output) + self.out_dist_tensor_spec.set_dims_mapping([-1, -1, 0]) + + resulted_dist_attrs = self.rule.infer_backward( + [self.x_dist_tensor_spec, self.y_dist_tensor_spec], + [self.out_dist_tensor_spec], + self.attrs, + ) + infered_input_dist_attrs = resulted_dist_attrs[0] + infered_output_dist_attrs = resulted_dist_attrs[1] + + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, -1, 0]) + self.assertEqual((infered_input_dist_attrs[1].dims_mapping), [0]) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, -1, 0]) + + self.x_dist_tensor_spec.shape = [64, 36, 12] + self.y_dist_tensor_spec.shape = [1, 12] + self.out_dist_tensor_spec.shape = [64, 36, 12] + # [-1, 0, -1] --> [-1, 0, -1], [-1, -1], [-1, 0, -1] (output --> inputs, output) + self.out_dist_tensor_spec.set_dims_mapping([-1, 0, -1]) + + resulted_dist_attrs = self.rule.infer_backward( + [self.x_dist_tensor_spec, self.y_dist_tensor_spec], + [self.out_dist_tensor_spec], + self.attrs, + ) + infered_input_dist_attrs = resulted_dist_attrs[0] + infered_output_dist_attrs = resulted_dist_attrs[1] + + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, 0, -1]) + self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [-1, -1]) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, 0, -1]) + + self.x_dist_tensor_spec.shape = [64, 1, 1, 12] + self.y_dist_tensor_spec.shape = [64, 32, 12] + self.out_dist_tensor_spec.shape = [64, 64, 32, 12] + # [0, -1, -1, -1] --> [0, -1, -1, -1], [-1, -1, -1], [0, -1, -1, -1] (output --> inputs, output) + self.out_dist_tensor_spec.set_dims_mapping([0, -1, -1, -1]) + + resulted_dist_attrs = self.rule.infer_backward( + [self.x_dist_tensor_spec, self.y_dist_tensor_spec], + [self.out_dist_tensor_spec], + self.attrs, + ) + infered_input_dist_attrs = resulted_dist_attrs[0] + infered_output_dist_attrs = resulted_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [0, -1, -1, -1] + ) + self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [-1, -1, -1]) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [0, -1, -1, -1] + ) + + # [-1, 0, -1, -1] --> [-1, -1, -1, -1], [0, -1, -1], [-1, 0, -1, -1] (output --> inputs, output) + self.out_dist_tensor_spec.set_dims_mapping([-1, 0, -1, -1]) + + resulted_dist_attrs = self.rule.infer_backward( + [self.x_dist_tensor_spec, self.y_dist_tensor_spec], + [self.out_dist_tensor_spec], + self.attrs, + ) + infered_input_dist_attrs = resulted_dist_attrs[0] + infered_output_dist_attrs = resulted_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, -1, -1, -1] + ) + self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [0, -1, -1]) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [-1, -0, -1, -1] + ) + + def test_backward_multi_mesh_dim(self): + process_mesh = auto.ProcessMesh([[0, 1, 2], [3, 4, 5]]) + self.x_dist_tensor_spec.set_process_mesh(process_mesh) + self.y_dist_tensor_spec.set_process_mesh(process_mesh) + self.x_dist_tensor_spec.shape = [96, 24, 48] + self.y_dist_tensor_spec.shape = [96, 24, 48] + self.out_dist_tensor_spec.shape = [96, 24, 48] + + # [0, 1, -1] --> [0, 1, -1], [0, 1, -1], [0, 1, -1] (output --> inputs, output) + self.out_dist_tensor_spec.set_dims_mapping([0, 1, -1]) + + resulted_dist_attrs = self.rule.infer_backward( + [self.x_dist_tensor_spec, self.y_dist_tensor_spec], + [self.out_dist_tensor_spec], + self.attrs, + ) + infered_input_dist_attrs = resulted_dist_attrs[0] + infered_output_dist_attrs = resulted_dist_attrs[1] + + self.assertEqual(len(resulted_dist_attrs), 2) + self.assertEqual(len(infered_input_dist_attrs), 2) + self.assertEqual(len(infered_output_dist_attrs), 1) + + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, 1, -1]) + self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [0, 1, -1]) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, 1, -1]) + + def test_backward_multi_mesh_dim_broadcast(self): + process_mesh = auto.ProcessMesh([[0, 1, 2], [3, 4, 5]]) + self.x_dist_tensor_spec.set_process_mesh(process_mesh) + self.y_dist_tensor_spec.set_process_mesh(process_mesh) + self.x_dist_tensor_spec.shape = [96, 24, 48] + self.y_dist_tensor_spec.shape = [48] + self.out_dist_tensor_spec.shape = [96, 24, 48] + + # [0, -1, 1] --> [0, -1, 1], [1], [0, -1, 1] (output --> inputs, output) + self.out_dist_tensor_spec.set_dims_mapping([0, -1, 1]) + + resulted_dist_attrs = self.rule.infer_backward( + [self.x_dist_tensor_spec, self.y_dist_tensor_spec], + [self.out_dist_tensor_spec], + self.attrs, + ) + infered_input_dist_attrs = resulted_dist_attrs[0] + infered_output_dist_attrs = resulted_dist_attrs[1] + + self.assertEqual(len(resulted_dist_attrs), 2) + self.assertEqual(len(infered_input_dist_attrs), 2) + self.assertEqual(len(infered_output_dist_attrs), 1) + + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1, 1]) + self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [1]) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, -1, 1]) + + # [0, 1, -1] --> [0, 1, -1], [-1], [0, 1, -1] (output --> inputs, output) + self.out_dist_tensor_spec.set_dims_mapping([0, 1, -1]) + + resulted_dist_attrs = self.rule.infer_backward( + [self.x_dist_tensor_spec, self.y_dist_tensor_spec], + [self.out_dist_tensor_spec], + self.attrs, + ) + infered_input_dist_attrs = resulted_dist_attrs[0] + infered_output_dist_attrs = resulted_dist_attrs[1] + + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, 1, -1]) + self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [-1]) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, 1, -1]) + + self.x_dist_tensor_spec.shape = [96, 1, 1, 48] + self.y_dist_tensor_spec.shape = [96, 24, 48] + self.out_dist_tensor_spec.shape = [96, 96, 24, 48] + + # [-1, 0, -1, 1] --> [-1, -1, -1, 1], [0, -1, 1], [-1, 0, -1, 1] (output --> inputs, output) + self.out_dist_tensor_spec.set_dims_mapping([-1, 0, -1, 1]) + + resulted_dist_attrs = self.rule.infer_backward( + [self.x_dist_tensor_spec, self.y_dist_tensor_spec], + [self.out_dist_tensor_spec], + self.attrs, + ) + infered_input_dist_attrs = resulted_dist_attrs[0] + infered_output_dist_attrs = resulted_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, -1, -1, 1] + ) + self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [0, -1, 1]) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [-1, 0, -1, 1] + ) + if __name__ == "__main__": unittest.main()