
pir infermeta func support inferspmd. #62659

Merged

1 commit merged into PaddlePaddle:develop on Mar 15, 2024

Conversation

@winter-wang (Contributor) commented Mar 12, 2024

PR types

New features

PR changes

Others

Description

pir infermeta func support inferspmd.

  • The InferMeta function of each pir op now automatically invokes the corresponding InferSpmd logic.

    • The currently generated InferMeta code for MatmulOp is as follows:

      // MatmulOp: newly added inline static InferSpmd function
      class MatmulOp : public pir::Op<MatmulOp,......> {
         public:
          // ........... other code omitted
          static phi::distributed::SpmdInfo InferSpmd(const phi::distributed::DistMetaTensor& x, const phi::distributed::DistMetaTensor& y, bool transpose_x, bool transpose_y) {
            return phi::distributed::MatmulInferSpmd(x, y, transpose_x, transpose_y);
          }
        
          static void InferMeta( phi::InferMetaContext *infer_meta );
          static std::vector<pir::Type> InferMeta( const std::vector<pir::Value>& input_values, pir::AttributeMap& attributes );
        };
      
      
        // MatmulOp::InferMeta implementation, with the newly added distributed logic
        std::vector<pir::Type> MatmulOp::InferMeta(const std::vector<pir::Value>& input_values, pir::AttributeMap& attributes) {
        
          IR_ENFORCE(input_values.size() == 2,
              "Num of inputs is expected to be 2 but got %d.", input_values.size());
        
          pir::Value x_ = input_values[0]; (void)x_;
          pir::Value y_ = input_values[1]; (void)y_;
          VLOG(4) << "Builder construction outputs";
          bool is_from_tensor = false; (void) is_from_tensor;
        
          paddle::dialect::DenseTensorType x;
          if (x_.type().isa<paddle::dialect::DenseTensorType>()) {
            x = x_.type().dyn_cast<paddle::dialect::DenseTensorType>(); (void)x;
          } else {
            PADDLE_THROW(phi::errors::Unimplemented("Only support paddle::dialect::DenseTensorType or paddle::dialect::AllocatedDenseTensorType"));
          }
        
          paddle::dialect::DenseTensorType y;
          if (y_.type().isa<paddle::dialect::DenseTensorType>()) {
            y = y_.type().dyn_cast<paddle::dialect::DenseTensorType>(); (void)y;
          } else {
            PADDLE_THROW(phi::errors::Unimplemented("Only support paddle::dialect::DenseTensorType or paddle::dialect::AllocatedDenseTensorType"));
          }
        
        
          IR_ENFORCE(
              attributes.find("transpose_x") != attributes.end(),
                  "'transpose_x' Attribute is expected for MatmulOp. ");
          bool transpose_x = attributes.at("transpose_x").dyn_cast<pir::BoolAttribute>().data();
        
          IR_ENFORCE(
              attributes.find("transpose_y") != attributes.end(),
                  "'transpose_y' Attribute is expected for MatmulOp. ");
          bool transpose_y = attributes.at("transpose_y").dyn_cast<pir::BoolAttribute>().data();
        
        
          VLOG(4) << "Builder construction  dense_x";
          paddle::dialect::IrTensor ir_tensor_x(paddle::dialect::TransToPhiDataType(x.dtype()),
                                                              x.dims(),
                                                              x.data_layout(),
                                                              x.lod(),
                                                              x.offset());
          VLOG(4) << "Builder construction  meta_x";
          paddle::dialect::IrMetaTensor meta_x(&ir_tensor_x);
        
          VLOG(4) << "Builder construction  dense_y";
          paddle::dialect::IrTensor ir_tensor_y(paddle::dialect::TransToPhiDataType(y.dtype()),
                                                              y.dims(),
                                                              y.data_layout(),
                                                              y.lod(),
                                                              y.offset());
          VLOG(4) << "Builder construction  meta_y";
          paddle::dialect::IrMetaTensor meta_y(&ir_tensor_y);
          paddle::dialect::IrTensor dense_out;
          paddle::dialect::IrMetaTensor meta_out(&dense_out);
        
          phi::MatmulInferMeta(meta_x, meta_y, transpose_x, transpose_y, &meta_out);
        
          std::vector<pir::Type> argument_outputs;
          pir::Type out_type = CvtToDenseTensorType(dense_out);
        
          // Newly added distributed InferSpmd logic; this code is generated only under distributed conditions
          if(!input_values.empty() && AllInputAreDist(input_values)) {
            ProcessMeshAttribute op_mesh = input_values[0].type().dyn_cast<DistDenseTensorType>().process_mesh_attr();
            std::vector<TensorDistAttribute> operand_dist_attrs, result_dist_attrs;
            auto dist_meta_x = CvtToDistMetaTensor(x_.type().dyn_cast<DistDenseTensorType>());
            auto dist_meta_y = CvtToDistMetaTensor(y_.type().dyn_cast<DistDenseTensorType>());
            auto spmd_info = InferSpmd(dist_meta_x, dist_meta_y, transpose_x, transpose_y);
            for(auto& arg_dist : spmd_info.first) {
                operand_dist_attrs.push_back(CvtToPirDistAttr(arg_dist));
            }
        
            auto dist_attr_out = CvtToPirDistAttr(spmd_info.second[0]);
            result_dist_attrs.push_back(dist_attr_out);
            argument_outputs.push_back(DistDenseTensorType::get(pir::IrContext::Instance(), out_type.dyn_cast<pir::DenseTensorType>(), dist_attr_out));
        
            attributes[kAttrOpDistAttrs] = OperationDistAttribute::get(
                pir::IrContext::Instance(),
                op_mesh,
                operand_dist_attrs,
                result_dist_attrs
            );
            return argument_outputs;
          }
        
          argument_outputs.push_back(out_type);
        
          return argument_outputs;
        }
  • Change const AttributeMap& to AttributeMap& in the signature of pir's InferMeta. This is needed because the InferSpmd logic inside InferMeta inserts distributed attributes into the AttributeMap (see the before/after sketch below).
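    A minimal sketch of the signature change (for illustration only; the actual generated declarations may differ slightly):

      // Before: attributes were passed read-only during shape/dtype inference
      static std::vector<pir::Type> InferMeta(
          const std::vector<pir::Value>& input_values,
          const pir::AttributeMap& attributes);

      // After: attributes are mutable, so the InferSpmd branch can insert
      // the kAttrOpDistAttrs entry built from the inferred spmd_info
      static std::vector<pir::Type> InferMeta(
          const std::vector<pir::Value>& input_values,
          pir::AttributeMap& attributes);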

Other

Pcard-67164

paddle-bot bot commented Mar 12, 2024

Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@winter-wang winter-wang force-pushed the ir_develop branch 19 times, most recently from 878fd8c to 36c989a Compare March 14, 2024 09:21
zhangbo9674 previously approved these changes Mar 14, 2024
zhiqiu previously approved these changes Mar 14, 2024
@zhiqiu (Contributor) left a comment

LGTM

@winter-wang winter-wang dismissed stale reviews from zhiqiu and zhangbo9674 via c326337 March 14, 2024 16:19
@winter-wang winter-wang force-pushed the ir_develop branch 2 times, most recently from c326337 to 24409c5 Compare March 14, 2024 16:34
@zhiqiu (Contributor) left a comment

LGTM

@winter-wang winter-wang merged commit feb5d69 into PaddlePaddle:develop Mar 15, 2024
30 checks passed