Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add C++ EinsumOp which support 2 operands einsum. #42105

Merged
merged 18 commits into from
Apr 26, 2022

Conversation

2742195759
Copy link
Contributor

@2742195759 2742195759 commented Apr 21, 2022

PR types

New features

PR changes

APIs

Describe

This PR developed C++ EinsumOp.

TODO:

  1. Split the InferMeta Function into unary.cc and unary.h
  2. Add more check and more informative error message in EinsumOp
  3. Support Multi-Inputs einsum in python: einsum() function. use the opt_einsum library to speed up.

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/kernels/impl/einsum_impl.h"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

<memory>,<unordered_map>,, einsum_impl.h这4个头文件是必需的么?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

目前einsum impl中有InferMeta,后续会提PR将这个函数移动到 unary.cc 中,下个删除

void Make() override {
AddInput("Operands", "(Tensor), The input tensor of svd op.")
.AsDuplicable();
AddOutput("Out", "(Tensor), The output VH tensor of svd op.");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
AddOutput("Out", "(Tensor), The output VH tensor of svd op.");
AddOutput("Out", "(Tensor), The output tensor of einsum op.");

using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
auto in_x = "Operands";
auto out_x_g_n = framework::GradVarName(in_x);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
auto out_x_g_n = framework::GradVarName(in_x);
auto x_name = "Operands";
auto x_grad_name = framework::GradVarName(x_name);

注意命名规范


#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/einsum_kernel.h"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#include "paddle/phi/kernels/einsum_kernel.h" 要放在第一行,且空行隔开


#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/einsum_kernel.h"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#include "paddle/phi/kernels/einsum_kernel.h" 要放在第一行,且空行隔开

} else {
PADDLE_ENFORCE_EQ((*labelshape)[c],
op_dim[dim_ptr],
phi::errors::InvalidArgument(""));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里的报错信息需要完善下

PADDLE_ENFORCE_EQ(v == 1 || broadcast_dims->at(idx) == 1 ||
broadcast_dims->at(idx) == v,
true,
phi::errors::InvalidArgument("can't broad cast."));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里的报错信息需要完善下

const LabelMap& labelshape,
std::vector<int>* output_dims) {
for (int c : right) {
if (c == '.')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

建议block用{} 括住

std::vector<int>* output_dims,
std::string* right) {
auto results = paddle::string::split_string(equation, "->");
auto left = results[0];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里是否需要先ENFORCE下 results的size是否符合预期?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个步骤在 ValidEinsum 函数中已经验证了。这里就没验证了。

VLOG(5) << "Einsum Infershape: output dims:"
<< paddle::string::join_strings(output_dims, ",");
print_label(all_labels, labeltype);
print_label(all_labels, labelshape);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这两个print放到VLOG下来触发吧,另外,op的vlog等级可以设置为3

namespace ops = paddle::operators;

DECLARE_INFER_SHAPE_FUNCTOR(einsum, EinsumInferShapeFunctor,
PD_INFER_META(phi::EinsumInferShape));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

是不是少提交了一个文件?

Copy link
Contributor

@XieYunshen XieYunshen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM for set_tests_properties(test_einsum_op PROPERTIES TIMEOUT 120)

Copy link
Contributor

@XieYunshen XieYunshen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM for set_tests_properties(test_einsum_op PROPERTIES TIMEOUT 120)

Copy link
Contributor

@chenwhql chenwhql left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

细节麻烦追加完善下

}
}

inline void EinsumInferShape(const std::vector<const MetaTensor*>& inputs,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

建议移到infermeta子目录

python/paddle/tensor/einsum.py Show resolved Hide resolved
auto it = std::find(all_labels.begin(), all_labels.end(), c);
PADDLE_ENFORCE_NE(it,
all_labels.end(),
phi::errors::InvalidArgument("Must in all_labels."));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

报错信息过短的问题后面也一并完善下

Copy link
Contributor

@XiaoguangHu01 XiaoguangHu01 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@Aurelius84 Aurelius84 merged commit c7302f9 into PaddlePaddle:develop Apr 26, 2022
@2742195759 2742195759 deleted the einsum branch April 26, 2022 08:50
2742195759 added a commit to 2742195759/Paddle that referenced this pull request Apr 28, 2022
* full api fix

* when out is None, go old dygraph mode

* by static check

* first version: support 2-inputs forwards. TODO: 1. backward  2. BroadCast  3. MultiVariable

* time out -> 120
XiaoguangHu01 pushed a commit that referenced this pull request Apr 28, 2022
* full api fix

* when out is None, go old dygraph mode

* by static check

* first version: support 2-inputs forwards. TODO: 1. backward  2. BroadCast  3. MultiVariable

* time out -> 120
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants