Skip to content

Commit

Permalink
Add C++ EinsumOp which support 2 operands einsum. (#42105) (#42357)
Browse files Browse the repository at this point in the history
* full api fix

* when out is None, go old dygraph mode

* by static check

* first version: support 2-inputs forwards. TODO: 1. backward  2. BroadCast  3. MultiVariable

* time out -> 120
  • Loading branch information
2742195759 authored Apr 28, 2022
1 parent c8b6654 commit d04a68d
Show file tree
Hide file tree
Showing 13 changed files with 1,213 additions and 0 deletions.
95 changes: 95 additions & 0 deletions paddle/fluid/operators/einsum_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <string>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/kernels/impl/einsum_impl.h"

namespace paddle {
namespace operators {
class EinsumOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
};

class EinsumOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Operands", "(TensorList), The input tensor of einsum op.")
.AsDuplicable();
AddOutput("Out", "(Tensor), The output tensor of einsum op.");
AddAttr<std::string>("equation",
"(string) A einsum equation. such as `ij,jk->ik`"
"There must have `->` and the number of operands in "
"equation must equals the `Operands` length.");
AddComment(R"DOC(
Einsum Operator.
This operator is used to perform einsum operation for given operands and equation.
)DOC");
}
};

class EinsumGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
auto x_name = "Operands";
auto x_grad_name = framework::GradVarName(x_name);
ctx->SetOutputsDim(x_grad_name, ctx->GetInputsDim(x_name));
ctx->ShareAllLoD(x_name, x_grad_name);
}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto dtype = OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));
return framework::OpKernelType(dtype, ctx.GetPlace());
}
};

template <typename T>
class EinsumGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

void Apply(GradOpPtr<T> retv) const override {
retv->SetType("einsum_grad");
retv->SetInput("Operands", this->Input("Operands"));
retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
retv->SetAttrMap(this->Attrs());
retv->SetOutput(framework::GradVarName("Operands"),
this->InputGrad("Operands", false));
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;

DECLARE_INFER_SHAPE_FUNCTOR(einsum, EinsumInferShapeFunctor,
PD_INFER_META(phi::EinsumInferShape));

REGISTER_OPERATOR(einsum, ops::EinsumOp, ops::EinsumOpMaker,
EinsumInferShapeFunctor,
ops::EinsumGradMaker<paddle::framework::OpDesc>,
ops::EinsumGradMaker<paddle::imperative::OpBase>);

REGISTER_OPERATOR(einsum_grad, ops::EinsumGradOp);
22 changes: 22 additions & 0 deletions paddle/phi/kernels/cpu/einsum_grad_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/einsum_grad_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/einsum_grad_impl.h"

PD_REGISTER_KERNEL(
einsum_grad, CPU, ALL_LAYOUT, phi::EinsumGradKernel, float, double) {}
20 changes: 20 additions & 0 deletions paddle/phi/kernels/cpu/einsum_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/einsum_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/einsum_impl.h"

PD_REGISTER_KERNEL(einsum, CPU, ALL_LAYOUT, phi::EinsumKernel, float, double) {}
28 changes: 28 additions & 0 deletions paddle/phi/kernels/einsum_grad_kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/core/dense_tensor.h"

namespace phi {

template <typename T, typename Context>
void EinsumGradKernel(const Context& dev_ctx,
const std::vector<const DenseTensor*>& x,
const DenseTensor& out_grad,
const std::string& equation,
std::vector<DenseTensor*> x_grad);

} // namespace phi
27 changes: 27 additions & 0 deletions paddle/phi/kernels/einsum_kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/core/dense_tensor.h"

namespace phi {

template <typename T, typename Context>
void EinsumKernel(const Context& dev_ctx,
const std::vector<const DenseTensor*>& inputs,
const std::string& equation,
DenseTensor* out);

} // namespace phi
22 changes: 22 additions & 0 deletions paddle/phi/kernels/gpu/einsum_grad_kernel.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/einsum_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/einsum_grad_impl.h"

PD_REGISTER_KERNEL(
einsum_grad, GPU, ALL_LAYOUT, phi::EinsumGradKernel, float, double) {}
21 changes: 21 additions & 0 deletions paddle/phi/kernels/gpu/einsum_kernel.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/einsum_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/einsum_impl.h"

PD_REGISTER_KERNEL(einsum, GPU, ALL_LAYOUT, phi::EinsumKernel, float, double) {}
Loading

0 comments on commit d04a68d

Please sign in to comment.