
[PTen] Move GetExpectedPtenKernelArgs function into pten for infrt #38825

Conversation

chenwhql
Contributor

@chenwhql chenwhql commented Jan 9, 2022

PR types

Function optimization

PR changes

Others

Describe

[PTen] Move GetExpectedPtenKernelArgs function into pten for infrt

When infrt executes, it loads and parses the program directly and selects kernels based on the program's information, without depending on paddle's op system. Therefore GetExpectedPtenKernelArgs, the function that maps Op arguments to kernel arguments, needs to become a pten component, so that the framework can use it flexibly and without a reverse dependency on the fluid implementation.

  1. Adaptation for fluid: see the example changes in this PR.
  • TODO: the adaptation can later be placed directly in the execution system, instead of overriding the GetExpectedPtenKernelArgs method in every op; since changing everything at once touches many files, this will be done gradually in follow-up PRs.
  2. Adaptation for infrt: a suitable ArgumentMappingContext can be created via inheritance, so that the corresponding matching function can be called from infrt.
// pseudocode
class ProtoArgumentMappingContext : public pten::ArgumentMappingContext {
 public:
  ProtoArgumentMappingContext(proto::OpProto* proto, proto::BlockDesc* block)
      : proto_(proto), block_(block) {}

  bool HasInput(const std::string& name) const override {
    // simple linear search
    for (int i = 0; i < proto_->input_size(); ++i) {
      auto& in = proto_->inputs()[i];
      if (in.name() == name) {
        return true;
      }
    }
    return false;
  }

  bool HasOutput(const std::string& name) const override {
    // simple linear search
    for (int i = 0; i < proto_->output_size(); ++i) {
      auto& out = proto_->outputs()[i];
      if (out.name() == name) {
        return true;
      }
    }
    return false;
  }

  bool HasAttr(const std::string& name) const override {
    // simple linear search
    for (int i = 0; i < proto_->attrs_size(); ++i) {
      auto& attr = proto_->attrs()[i];
      if (attr.name() == name) {
        return true;
      }
    }
    return false;
  }

  size_t InputSize(const std::string& name) const override {
    return proto_->input_size();
  }

  size_t OutputSize(const std::string& name) const override {
    return proto_->output_size();
  }

  bool IsDenseTensorInput(const std::string& name) const override {
    for (int i = 0; i < block_->vars_size(); ++i) {
      auto& var = block_->vars()[i];
      if (var.name() == name && var.type() == proto::VarType::LOD_TENSOR) {
        return true;
      }
    }
    // TODO(chenweihang): throw an error when the variable cannot be found
    return false;
  }

  bool IsSelectedRowsInput(const std::string& name) const override {
    for (int i = 0; i < block_->vars_size(); ++i) {
      auto& var = block_->vars()[i];
      if (var.name() == name && var.type() == proto::VarType::SELECTED_ROWS) {
        return true;
      }
    }
    // TODO(chenweihang): throw an error when the variable cannot be found
    return false;
  }

 private:
  proto::OpProto* proto_;
  proto::BlockDesc* block_;
};

@paddle-bot-old

paddle-bot-old bot commented Jan 9, 2022

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.


namespace pten {

KernelSignature ScaleOpArgumentMapping(const ArgumentMappingContext& ctx) {
Contributor

It feels like if this function could also be registered into a map, the original Ops probably wouldn't need to write GetExpectedPtenKernelArgs individually anymore.

Contributor Author
Yes, that will be done in the next PR.

Contributor

@zhiqiu zhiqiu left a comment

LGTM

@chenwhql chenwhql merged commit 3a23c1a into PaddlePaddle:develop Jan 10, 2022