Skip to content

Commit

Permalink
Move BroadcastTensors OP to phi (#40047)
Browse files Browse the repository at this point in the history
* Move BroadcastTensors OP to phi

* Remove mutable_data in impl

* Move BilinearTensorProductInferMeta to multiary.h/cc
  • Loading branch information
From00 authored Mar 2, 2022
1 parent 8492d3b commit 2a5590a
Show file tree
Hide file tree
Showing 15 changed files with 658 additions and 504 deletions.
99 changes: 10 additions & 89 deletions paddle/fluid/operators/broadcast_tensors_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/broadcast_tensors_op.h"

#include <algorithm>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/var_type_inference.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/multiary.h"

namespace paddle {
namespace operators {
Expand All @@ -31,64 +27,6 @@ class BroadcastTensorsOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

// Computes the common broadcast shape of all "X" inputs and assigns it to
// every "Out" output.
//
// The output rank is the largest input rank. Each output axis, aligned from
// the trailing dimension, is the largest size found on that axis across the
// inputs; inputs that are too short contribute size 1. Two sizes that differ
// and are both != 1 cannot be broadcast and raise InvalidArgument.
// InvalidArgument is also raised when no input has rank greater than zero.
void InferShape(framework::InferShapeContext* ctx) const override {
  OP_INOUT_CHECK(ctx->HasInputs("X"), "Input", "X", "broadcast_tensors");
  OP_INOUT_CHECK(ctx->HasOutputs("Out"), "Output", "Out",
                 "broadcast_tensors");

  int target_rank = 0;
  const auto& input_dims = ctx->GetInputsDim("X");

  // 1. Find Output rank = max(Inputs rank)
  for (const auto& input_ddim : input_dims) {
    target_rank = std::max(target_rank, input_ddim.size());
  }

  // Adjacent literals are concatenated verbatim, so the separating space
  // has to be part of the first literal.
  PADDLE_ENFORCE_GT(
      target_rank, 0,
      platform::errors::InvalidArgument(
          "BroadcastTensorsOp requires at least one input tensor "
          "to have rank greater than zero"));

  std::vector<int64_t> target_dims(target_rank, 0);
  // 2. Output dim(axis=x) = max(Inputs dim(axis=x))
  for (int index = 0; index < target_rank; index++) {
    // Loop axes in reverse order,
    // For each axis, take the maximum as target size
    // Fill size = 1 if shape vector exhausts
    //
    // Sizes are kept as int64_t (the element type of target_dims) so that a
    // large extent is not narrowed on its way into the output shape.
    int64_t target_dim_size = 1;
    for (const auto& input_ddim : input_dims) {
      // Reversed order
      const int axis = static_cast<int>(input_ddim.size()) - index - 1;
      int64_t dim_size = 1;
      if (axis >= 0) {
        dim_size = input_ddim[axis];
      }

      if (target_dim_size != 1 && dim_size != 1 &&
          target_dim_size != dim_size) {
        PADDLE_THROW(platform::errors::InvalidArgument(
            "BroadcastTensorsOp inputs does not satisfy bcast semantics, "
            "Please check axis = %d in reverse order",
            index));
      }

      // We performed bcast semantics check at python level
      // So input tensors should all have legal shape
      target_dim_size = std::max(target_dim_size, dim_size);
    }
    target_dims[target_rank - index - 1] = target_dim_size;
  }

  // 3. Set Output Dim: one copy of the broadcast shape per input tensor.
  std::vector<DDim> output_ddims;
  output_ddims.reserve(input_dims.size());
  for (size_t i = 0; i < input_dims.size(); i++) {
    output_ddims.emplace_back(phi::make_ddim(target_dims));
  }
  ctx->SetOutputsDim("Out", output_ddims);
  ctx->ShareAllLoD("X", /*->*/ "Out");
}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
Expand Down Expand Up @@ -229,34 +167,17 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(BroadcastTensorsGradNoNeedBufVarsInferer,
namespace ops = paddle::operators;
namespace plat = paddle::platform;

// Declares BroadcastTensorsInferShapeFunctor for the broadcast_tensors op,
// built from phi::BroadcastTensorsInferMeta via PT_INFER_META.
// NOTE(review): the macro name spelling is the one used at this call site;
// confirm against the macro's definition before renaming.
DELCARE_INFER_SHAPE_FUNCTOR(broadcast_tensors,
BroadcastTensorsInferShapeFunctor,
PT_INFER_META(phi::BroadcastTensorsInferMeta));

// Registers the forward broadcast_tensors operator, passing its op maker,
// the grad-op makers for OpDesc and imperative OpBase, and its var-type
// inference class.
REGISTER_OPERATOR(broadcast_tensors, ops::BroadcastTensorsOp,
ops::BroadcastTensorsOpMaker,
ops::BroadcastTensorsGradOpMaker<paddle::framework::OpDesc>,
ops::BroadcastTensorsGradOpMaker<paddle::imperative::OpBase>,
ops::BroadcastTensorsOpVarTypeInference);
ops::BroadcastTensorsOpVarTypeInference,
BroadcastTensorsInferShapeFunctor);

// Registers the backward broadcast_tensors_grad operator, passing its
// var-type inference class and its no-need-buffer-vars inferer.
REGISTER_OPERATOR(broadcast_tensors_grad, ops::BroadcastTensorsGradOp,
ops::BroadcastTensorsGradOpVarTypeInference,
ops::BroadcastTensorsGradNoNeedBufVarsInferer);

// Registers the CPU kernels of broadcast_tensors: BroadcastTensorsOpKernel
// instantiated on CPUDeviceContext for float16, float, double, bool, int
// and int64_t.
REGISTER_OP_CPU_KERNEL(
broadcast_tensors,
ops::BroadcastTensorsOpKernel<paddle::platform::CPUDeviceContext,
plat::float16>,
ops::BroadcastTensorsOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::BroadcastTensorsOpKernel<paddle::platform::CPUDeviceContext, double>,
ops::BroadcastTensorsOpKernel<paddle::platform::CPUDeviceContext, bool>,
ops::BroadcastTensorsOpKernel<paddle::platform::CPUDeviceContext, int>,
ops::BroadcastTensorsOpKernel<paddle::platform::CPUDeviceContext, int64_t>);

// Registers the CPU kernels of broadcast_tensors_grad:
// BroadcastTensorsGradOpKernel instantiated on CPUDeviceContext for float16,
// float, double, int and int64_t (no bool instantiation is listed here,
// unlike the forward kernel).
REGISTER_OP_CPU_KERNEL(
broadcast_tensors_grad,
ops::BroadcastTensorsGradOpKernel<paddle::platform::CPUDeviceContext,
plat::float16>,
ops::BroadcastTensorsGradOpKernel<paddle::platform::CPUDeviceContext,
float>,
ops::BroadcastTensorsGradOpKernel<paddle::platform::CPUDeviceContext,
double>,
ops::BroadcastTensorsGradOpKernel<paddle::platform::CPUDeviceContext, int>,
ops::BroadcastTensorsGradOpKernel<paddle::platform::CPUDeviceContext,
int64_t>);
122 changes: 0 additions & 122 deletions paddle/fluid/operators/broadcast_tensors_op.cu

This file was deleted.

Loading

0 comments on commit 2a5590a

Please sign in to comment.