
add dist attribute for mutable attribute. #62897

Merged: 2 commits into PaddlePaddle:develop from ir_develop (Mar 26, 2024)

Conversation

@winter-wang (Contributor) commented Mar 20, 2024

PR types

New features

PR changes

Others

Description

  • Adapt mutable attributes to distributed scenarios.
    • When any of an op's other inputs is a distributed tensor, the input that carries a mutable attribute is supplemented with distributed attributes. After this PR is merged, the InferMeta logic is modified as follows (a Python-side sketch of the resulting pattern follows the code):
      // Implementation of FullLikeOp::InferMeta, including the distributed InferSpmd logic.
      std::vector<pir::Type> FullLikeOp::InferMeta(const std::vector<pir::Value>& input_values, pir::AttributeMap& attributes) {
        // ... other single-machine code omitted ...
        // The following part is distributed-only logic.
        if (HasDistInput(input_values)) {
          auto ctx = pir::IrContext::Instance();
          ProcessMeshAttribute op_mesh;

          // Obtain a default ProcessMesh from whichever input is already distributed.
          for (auto value : input_values) {
            if (auto dist_interface = value.type().dyn_cast<DistTypeInterface>()) {
              op_mesh = dist_interface.process_mesh_attr();
              break;
            }
          }
          // Supplement each mutable attribute with distributed attributes in turn,
          // where mesh is the default ProcessMesh obtained above and dims_mapping is all -1.
          // value is the mutable attribute; distributed attributes are added to its defining op.
          if (!value.FromTensor()) {
            auto dist_type = DistDenseTensorType::get(ctx, value_.type().dyn_cast<DenseTensorType>(), op_mesh);
            value_.set_type(dist_type);
            value_.defining_op()->set_attribute(kAttrOpDistAttr, OperationDistAttribute::get(ctx, op_mesh, {dist_type.tensor_dist_attr()}, {}));
          }
          if (!AllInputAreDist(input_values)) {
            PADDLE_THROW(common::errors::Unimplemented(
                "Current not supported mixed distributed inputs."));
          }
          std::vector<TensorDistAttribute> operand_dist_attrs, result_dist_attrs;
          auto dist_meta_x = CvtToDistMetaTensor(x_.type().dyn_cast<DistDenseTensorType>());
          auto spmd_info = InferSpmd(dist_meta_x, value, dtype);
          for (auto& arg_dist : spmd_info.first) {
            operand_dist_attrs.push_back(CvtToPirDistAttr(arg_dist));
          }

          auto dist_attr_out = CvtToPirDistAttr(spmd_info.second[0]);
          result_dist_attrs.push_back(dist_attr_out);
          argument_outputs.push_back(DistDenseTensorType::get(ctx, out_type.dyn_cast<pir::DenseTensorType>(), dist_attr_out));

          attributes[kAttrOpDistAttr] = OperationDistAttribute::get(
              ctx,
              op_mesh,
              operand_dist_attrs,
              result_dist_attrs);
          return argument_outputs;
        }
        // ... other single-machine code omitted ...
      }
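      For context, the full/full_like pair in the program dump further below arises from exactly this pattern. The following is a minimal Python-side sketch, assuming the graph is captured into a PIR program via dynamic-to-static; the mesh, shapes, and placements are illustrative and not taken from this PR's test:

      import paddle
      import paddle.distributed as dist

      mesh = dist.ProcessMesh([0, 1])  # mesh_shape:[2], matching the dump below
      # Mark x as a distributed tensor, replicated over the single mesh dimension.
      x = dist.shard_tensor(paddle.randn([4, 8]), mesh, [dist.Replicate()])

      # Under dynamic-to-static, the literal 1.0 lowers to a pd_op.full result that
      # feeds pd_op.full_like as a mutable attribute. Because x is a distributed
      # tensor, the InferMeta logic above supplements the full op's result with a
      # default dist attribute (the op mesh, dims_mapping all -1).
      y = paddle.full_like(x, 1.0)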
  • Adapt automatic differentiation to distributed scenarios (a sketch of this flow follows the program dump below).
    • Add a get_output_value_by_name pybind interface to Program, which looks up the corresponding value by name.
    • The shadow_output API now also sets the op's distributed attributes.
    • For distributed models in train and eval modes, when dynamic-to-static conversion produces the value corresponding to the loss, that value is marked as a hidden output (described by a shadow_output op), and a loss_names member variable is added to the engine to record the names of these hidden outputs. After the passes finish, the hidden output with the same name is located in the new program via loss_names and used as the loss for automatic differentiation.
    • In the test_to_static_pir_program unit test added in this PR, after the composed program goes through append_backward, the program looks like this:
      {
       (%0) = "builtin.parameter" () {is_persistable:[true],op_dist_attr:{mesh:{shape:[2],process_ids:[0,1]},result(0):{dims_maping:[0,-1]}},parameter_name:"parameter_1",stop_gradient:[false]} : () -> pd_dist.tensor<16x8xf32, mesh_shape:[2],dims_mappings:[0,-1]>
       (%1) = "builtin.parameter" () {is_persistable:[true],op_dist_attr:{mesh:{shape:[2],process_ids:[0,1]},result(0):{dims_maping:[-1,0]}},parameter_name:"parameter_0",stop_gradient:[false]} : () -> pd_dist.tensor<16x16xf32, mesh_shape:[2],dims_mappings:[-1,0]>
       (%2) = "pd_op.data" () {dtype:(pd_op.DataType)float32,name:"input0",op_dist_attr:{mesh:{shape:[2],process_ids:[0,1]},result(0):{dims_maping:[-1,-1]}},place:(pd_op.Place)Place(undefined:0),shape:(pd_op.IntArray)[4,16],stop_gradient:[true]} : () -> pd_dist.tensor<4x16xf32, mesh_shape:[2],dims_mappings:[-1,-1]>
       (%3) = "pd_op.data" () {dtype:(pd_op.DataType)float32,name:"label0",op_dist_attr:{mesh:{shape:[2],process_ids:[0,1]},result(0):{dims_maping:[-1,-1]}},place:(pd_op.Place)Place(undefined:0),shape:(pd_op.IntArray)[4,8],stop_gradient:[true]} : () -> pd_dist.tensor<4x8xf32, mesh_shape:[2],dims_mappings:[-1,-1]>
       (%4) = "pd_op.relu" (%2) {op_dist_attr:{mesh:{shape:[2],process_ids:[0,1]},operand(0):{dims_maping:[-1,-1]},result(0):{dims_maping:[-1,-1]}},stop_gradient:[true]} : (pd_dist.tensor<4x16xf32, mesh_shape:[2],dims_mappings:[-1,-1]>) -> pd_dist.tensor<4x16xf32, mesh_shape:[2],dims_mappings:[-1,-1]>
       (%5) = "pd_op.matmul" (%4, %1) {op_dist_attr:{mesh:{shape:[2],process_ids:[0,1]},operand(0):{dims_maping:[-1,-1]},operand(1):{dims_maping:[-1,0]},result(0):{dims_maping:[-1,0]}},stop_gradient:[false],transpose_x:false,transpose_y:false} : (pd_dist.tensor<4x16xf32, mesh_shape:[2],dims_mappings:[-1,-1]>, pd_dist.tensor<16x16xf32, mesh_shape:[2],dims_mappings:[-1,0]>) -> pd_dist.tensor<4x16xf32, mesh_shape:[2],dims_mappings:[-1,0]>
       (%6) = "pd_op.relu" (%5) {op_dist_attr:{mesh:{shape:[2],process_ids:[0,1]},operand(0):{dims_maping:[-1,0]},result(0):{dims_maping:[-1,0]}},stop_gradient:[false]} : (pd_dist.tensor<4x16xf32, mesh_shape:[2],dims_mappings:[-1,0]>) -> pd_dist.tensor<4x16xf32, mesh_shape:[2],dims_mappings:[-1,0]>
       (%7) = "pd_op.matmul" (%6, %0) {op_dist_attr:{mesh:{shape:[2],process_ids:[0,1]},operand(0):{dims_maping:[-1,0]},operand(1):{dims_maping:[0,-1]},result(0):{dims_maping:[-1,-1],partial(0,SUM)}},stop_gradient:[false],transpose_x:false,transpose_y:false} : (pd_dist.tensor<4x16xf32, mesh_shape:[2],dims_mappings:[-1,0]>, pd_dist.tensor<16x8xf32, mesh_shape:[2],dims_mappings:[0,-1]>) -> pd_dist.tensor<4x8xf32, mesh_shape:[2],dims_mappings:[-1,-1], partial(0,SUM)>
       (%8) = "pd_op.relu" (%7) {op_dist_attr:{mesh:{shape:[2],process_ids:[0,1]},operand(0):{dims_maping:[-1,-1]},result(0):{dims_maping:[-1,-1]}},stop_gradient:[false]} : (pd_dist.tensor<4x8xf32, mesh_shape:[2],dims_mappings:[-1,-1], partial(0,SUM)>) -> pd_dist.tensor<4x8xf32, mesh_shape:[2],dims_mappings:[-1,-1]>
       (%9) = "pd_op.subtract" (%8, %3) {op_dist_attr:{mesh:{shape:[2],process_ids:[0,1]},operand(0):{dims_maping:[-1,-1]},operand(1):{dims_maping:[-1,-1]},result(0):{dims_maping:[-1,-1]}},stop_gradient:[false]} : (pd_dist.tensor<4x8xf32, mesh_shape:[2],dims_mappings:[-1,-1]>, pd_dist.tensor<4x8xf32, mesh_shape:[2],dims_mappings:[-1,-1]>) -> pd_dist.tensor<4x8xf32, mesh_shape:[2],dims_mappings:[-1,-1]>
       (%10) = "pd_op.square" (%9) {op_dist_attr:{mesh:{shape:[2],process_ids:[0,1]},operand(0):{dims_maping:[-1,-1]},result(0):{dims_maping:[-1,-1]}},stop_gradient:[false]} : (pd_dist.tensor<4x8xf32, mesh_shape:[2],dims_mappings:[-1,-1]>) -> pd_dist.tensor<4x8xf32, mesh_shape:[2],dims_mappings:[-1,-1]>
       (%11) = "pd_op.mean" (%10) {axis:(pd_op.IntArray)[],keepdim:false,op_dist_attr:{mesh:{shape:[2],process_ids:[0,1]},operand(0):{dims_maping:[-1,-1]},result(0):{dims_maping:[]}},stop_gradient:[false]} : (pd_dist.tensor<4x8xf32, mesh_shape:[2],dims_mappings:[-1,-1]>) -> pd_dist.tensor<f32, mesh_shape:[2],dims_mappings:[]>
        () = "builtin.shadow_output" (%11) {op_dist_attr:{mesh:{shape:[2],process_ids:[0,1]},operand(0):{dims_maping:[]}},output_name:"loss_0"} : (pd_dist.tensor<f32, mesh_shape:[2],dims_mappings:[]>) -> 
       (%12) = "pd_op.full" () {dtype:(pd_op.DataType)float32,op_dist_attr:{mesh:{shape:[2],process_ids:[0,1]},operand(0):{dims_maping:[-1]}},place:(pd_op.Place)Place(cpu),shape:(pd_op.IntArray)[1],stop_gradient:[true],value:(Float)1} : () -> pd_dist.tensor<1xf32, mesh_shape:[2],dims_mappings:[-1]> 
       (%13) = "pd_op.full_like" (%11, %12) {dtype:(pd_op.DataType)float32,op_dist_attr:{mesh:{shape:[2],process_ids:[0,1]},operand(0):{dims_maping:[]},result(0):{dims_maping:[]}},place:(pd_op.Place)Place(undefined:0),stop_gradient:[false]} : (pd_dist.tensor<f32, mesh_shape:[2],dims_mappings:[]>, pd_dist.tensor<1xf32, mesh_shape:[2],dims_mappings:[-1]>) -> pd_dist.tensor<f32, mesh_shape:[2],dims_mappings:[]>
       (%14) = "pd_op.mean_grad" (%10, %13) {axis:(pd_op.IntArray)[],keepdim:false,op_dist_attr:{mesh:{shape:[2],process_ids:[0,1]},operand(0):{dims_maping:[-1,-1]},operand(1):{dims_maping:[]},result(0):{dims_maping:[-1,-1]}},reduce_all:false,stop_gradient:[false]} : (pd_dist.tensor<4x8xf32, mesh_shape:[2],dims_mappings:[-1,-1]>, pd_dist.tensor<f32, mesh_shape:[2],dims_mappings:[]>) -> pd_dist.tensor<4x8xf32, mesh_shape:[2],dims_mappings:[-1,-1]>
       (%15) = "pd_op.square_grad" (%9, %14) {op_dist_attr:{mesh:{shape:[2],process_ids:[0,1]},operand(0):{dims_maping:[-1,-1]},operand(1):{dims_maping:[-1,-1]},result(0):{dims_maping:[-1,-1]}},stop_gradient:[false]} : (pd_dist.tensor<4x8xf32, mesh_shape:[2],dims_mappings:[-1,-1]>, pd_dist.tensor<4x8xf32, mesh_shape:[2],dims_mappings:[-1,-1]>) -> pd_dist.tensor<4x8xf32, mesh_shape:[2],dims_mappings:[-1,-1]>
       (%16, %17) = "pd_op.subtract_grad" (%8, %3, %15) {axis:(Int32)-1,op_dist_attr:{mesh:{shape:[2],process_ids:[0,1]},operand(0):{dims_maping:[-1,-1]},operand(1):{dims_maping:[-1,-1]},operand(2):{dims_maping:[-1,-1]},result(0):{dims_maping:[-1,-1]},result(1):{dims_maping:[-1,-1]}},stop_gradient:[false,false]} : (pd_dist.tensor<4x8xf32, mesh_shape:[2],dims_mappings:[-1,-1]>, pd_dist.tensor<4x8xf32, mesh_shape:[2],dims_mappings:[-1,-1]>, pd_dist.tensor<4x8xf32, mesh_shape:[2],dims_mappings:[-1,-1]>) -> pd_dist.tensor<4x8xf32, mesh_shape:[2],dims_mappings:[-1,-1]>, <<NULL TYPE>>
       (%18) = "pd_op.relu_grad" (%8, %16) {op_dist_attr:{mesh:{shape:[2],process_ids:[0,1]},operand(0):{dims_maping:[-1,-1]},operand(1):{dims_maping:[-1,-1]},result(0):{dims_maping:[-1,-1]}},stop_gradient:[false]} : (pd_dist.tensor<4x8xf32, mesh_shape:[2],dims_mappings:[-1,-1]>, pd_dist.tensor<4x8xf32, mesh_shape:[2],dims_mappings:[-1,-1]>) -> pd_dist.tensor<4x8xf32, mesh_shape:[2],dims_mappings:[-1,-1]>
       (%19, %20) = "pd_op.matmul_grad" (%6, %0, %18) {op_dist_attr:{mesh:{shape:[2],process_ids:[0,1]},operand(0):{dims_maping:[-1,0]},operand(1):{dims_maping:[0,-1]},operand(2):{dims_maping:[-1,-1]},result(0):{dims_maping:[-1,0]},result(1):{dims_maping:[0,-1]}},stop_gradient:[false,false],transpose_x:false,transpose_y:false} : (pd_dist.tensor<4x16xf32, mesh_shape:[2],dims_mappings:[-1,0]>, pd_dist.tensor<16x8xf32, mesh_shape:[2],dims_mappings:[0,-1]>, pd_dist.tensor<4x8xf32, mesh_shape:[2],dims_mappings:[-1,-1]>) -> pd_dist.tensor<4x16xf32, mesh_shape:[2],dims_mappings:[-1,0]>, pd_dist.tensor<16x8xf32, mesh_shape:[2],dims_mappings:[0,-1]>
       (%21) = "pd_op.relu_grad" (%6, %19) {op_dist_attr:{mesh:{shape:[2],process_ids:[0,1]},operand(0):{dims_maping:[-1,0]},operand(1):{dims_maping:[-1,0]},result(0):{dims_maping:[-1,0]}},stop_gradient:[false]} : (pd_dist.tensor<4x16xf32, mesh_shape:[2],dims_mappings:[-1,0]>, pd_dist.tensor<4x16xf32, mesh_shape:[2],dims_mappings:[-1,0]>) -> pd_dist.tensor<4x16xf32, mesh_shape:[2],dims_mappings:[-1,0]>
       (%22, %23) = "pd_op.matmul_grad" (%4, %1, %21) {op_dist_attr:{mesh:{shape:[2],process_ids:[0,1]},operand(0):{dims_maping:[-1,-1]},operand(1):{dims_maping:[-1,0]},operand(2):{dims_maping:[-1,0]},result(0):{dims_maping:[-1,-1],partial(0,SUM)},result(1):{dims_maping:[-1,0]}},stop_gradient:[false,false],transpose_x:false,transpose_y:false} : (pd_dist.tensor<4x16xf32, mesh_shape:[2],dims_mappings:[-1,-1]>, pd_dist.tensor<16x16xf32, mesh_shape:[2],dims_mappings:[-1,0]>, pd_dist.tensor<4x16xf32, mesh_shape:[2],dims_mappings:[-1,0]>) -> <<NULL TYPE>>, pd_dist.tensor<16x16xf32, mesh_shape:[2],dims_mappings:[-1,0]>}
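    The loss-recovery flow just described can be sketched as follows. This is a minimal sketch rather than the exact engine code: the build_backward wrapper and the import path of append_backward are assumptions for illustration, while get_output_value_by_name and the recorded loss names are what this PR adds.

      # Sketch: recover the loss in the transformed program and differentiate it.
      from paddle.autograd.ir_backward import append_backward  # assumed import path

      def build_backward(dist_program, loss_names, mode):
          if mode == "predict" or not loss_names:
              return None
          # The loss value captured before the passes ran is stale; look the hidden
          # output (shadow_output) up again by its recorded name in the new program.
          loss = dist_program.get_output_value_by_name(loss_names[0])
          # The current PIR backward API supports a single loss, so only the first
          # recorded name is used.
          return append_backward(loss)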

Other

Pcard-67164

paddle-bot commented Mar 20, 2024

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@winter-wang force-pushed the ir_develop branch 3 times, most recently from d7d077b to 2aabf38, on March 21, 2024
@winter-wang force-pushed the ir_develop branch 6 times, most recently from 67f0741 to 15b1c85, on March 25, 2024
    pir::StrAttribute::get(IrContext::Instance(), name);
    for (auto &op : block) {
      if (op.isa<pir::ShadowOutputOp>()) {
        if (op.attribute("output_name") == name_attr) {
Contributor commented on the excerpt above:

We need to assert that the name is unique.

winter-wang (Contributor, Author) replied:

Thanks very much; I will fix this in the next PR.
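A uniqueness check of the kind requested might look like the following Python sketch; the global_block, name, and attrs accessors are assumed from the pir Python bindings, and this is not code from the PR:

    def assert_shadow_output_names_unique(program):
        # Collect every shadow_output name and fail on duplicates.
        seen = set()
        for op in program.global_block().ops:
            if op.name() == "builtin.shadow_output":
                name = op.attrs()["output_name"]
                assert name not in seen, f"duplicate shadow_output name: {name}"
                seen.add(name)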


    # Step 1.2: pir backward
    if mode != "predict" and self._loss:
        loss = dist_program.get_output_value_by_name(self._loss_names[0])
Contributor commented on the excerpt above:

What about multiple losses?

winter-wang (Contributor, Author) replied:

The current PIR backward API only supports the single-loss case, so only the first loss is processed here. I will add a comment here in the next PR.

@zhiqiu (Contributor) left a comment:

LGTM

@winter-wang merged commit f211563 into PaddlePaddle:develop Mar 26, 2024
29 of 30 checks passed
@JZ-LIANG (Contributor) left a comment:

LGTM

co63oc pushed a commit to co63oc/Paddle that referenced this pull request Mar 26, 2024
* add dist attribute for mutable attribute.

* support backward for distribute pir.