Skip to content

Commit

Permalink
Bert Support (PaddlePaddle#223)
Browse files Browse the repository at this point in the history
* lookup_table resolve
  • Loading branch information
yaozhixin committed Oct 19, 2021
1 parent 27610ff commit fcd219e
Show file tree
Hide file tree
Showing 7 changed files with 315 additions and 46 deletions.
64 changes: 64 additions & 0 deletions paddle/fluid/framework/ipu/popart_canonicalization/math_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,10 @@ Node *cross_entropy2_handler(Graph *graph, Node *node) {
});
} else {
std::vector<int64_t> new_shape_{label_shape_[0]};
// workaround for bert
if (GetInputVarNode("Label", node)->Name() == "next_sentence_labels") {
new_shape_ = {label_shape_[0], label_shape_[1]};
}
auto const_before_loss = CreateBaseOp(
graph, node, "popart_constant", {}, {},
{{"value", new_shape_},
Expand Down Expand Up @@ -248,6 +252,64 @@ Node *cross_entropy2_handler(Graph *graph, Node *node) {
}
}

Node *cumsum_handler(Graph *graph, Node *node) {
auto *op = node->Op();
auto exclusive = BOOST_GET_CONST(bool, op->GetAttr("exclusive"));
int64_t popart_exclusive = 1 ? exclusive : 0;
auto reverse = BOOST_GET_CONST(bool, op->GetAttr("reverse"));
int64_t popart_reverse = 1 ? reverse : 0;
auto axis = BOOST_GET_CONST(int, op->GetAttr("axis"));
auto axis_node =
CreateConst(graph, node, {}, {}, {{"value", std::vector<int64_t>{axis}},
{"dims", std::vector<int64_t>{1}},
{"dtype", ONNXDataType::INT64}});
return CreateBaseOp(
graph, node, "popart_cumsum",
{GetInputVarNode("X", node), axis_node->outputs[0]},
{GetOutputVarNode("Out", node)},
{{"exclusive", popart_exclusive}, {"reverse", popart_reverse}});
}

Node *matmul_v2_handler(Graph *graph, Node *node) {
auto *op = node->Op();
auto transpose_x = BOOST_GET_CONST(bool, op->GetAttr("trans_x"));
auto transpose_y = BOOST_GET_CONST(bool, op->GetAttr("trans_y"));
auto x_shape = GetInputVarNode("X", node)->Var()->GetShape();
auto y_shape = GetInputVarNode("Y", node)->Var()->GetShape();

std::vector<int64_t> perm;
int x_rank = x_shape.size();
if (x_rank == 1) {
perm = std::vector<int64_t>{0};
} else if (x_rank == 2) {
perm = std::vector<int64_t>{1, 0};
} else if (x_rank == 3) {
perm = std::vector<int64_t>{0, 2, 1};
} else if (x_rank == 4) {
perm = std::vector<int64_t>{0, 1, 3, 2};
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"op matmul with input rank == %d", x_rank));
}

Node *x_node = GetInputVarNode("X", node);
Node *y_node = GetInputVarNode("Y", node);

if (transpose_x) {
x_node = CreateBaseOp(graph, node, "popart_transpose",
{GetInputVarNode("X", node)}, {}, {{"perm", perm}});
x_node = x_node->outputs[0];
}
if (transpose_y) {
y_node = CreateBaseOp(graph, node, "popart_transpose",
{GetInputVarNode("Y", node)}, {}, {{"perm", perm}});
y_node = y_node->outputs[0];
}

return CreateBaseOp(graph, node, "popart_matmul", {x_node, y_node},
node->outputs);
}

REGISTER_HANDLER(mean, mean_handler);
REGISTER_HANDLER(pow, pow_handler);
REGISTER_HANDLER(mul, mul_handler);
Expand All @@ -256,6 +318,8 @@ REGISTER_HANDLER(sum, sum_handler);
REGISTER_HANDLER(softmax, softmax_handler);
REGISTER_HANDLER(scale, scale_handler);
REGISTER_HANDLER(cross_entropy2, cross_entropy2_handler);
REGISTER_HANDLER(cumsum, cumsum_handler);
REGISTER_HANDLER(matmul_v2, matmul_v2_handler);

} // namespace
} // namespace ipu
Expand Down
70 changes: 64 additions & 6 deletions paddle/fluid/framework/ipu/popart_canonicalization/tensor_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -244,13 +244,18 @@ Node *lookup_table_handler(Graph *graph, Node *node) {
w_node = GetInputVarNode("W", node);
}

auto squeeze = CreateBaseOp(graph, node, "popart_squeeze",
{GetInputVarNode("Ids", node)}, {},
{{"axes", std::vector<int64_t>{-1}}});
// support lookup_table_v2
auto ids = GetInputVarNode("Ids", node);
auto ids_shape = ids->Var()->GetShape();
if (ids_shape[ids_shape.size() - 1] == 1) {
ids = CreateBaseOp(graph, node, "popart_squeeze",
{GetInputVarNode("Ids", node)}, {},
{{"axes", std::vector<int64_t>{-1}}});
ids = ids->outputs[0];
}

auto gather =
CreateBaseOp(graph, node, "popart_gather", {w_node, squeeze->outputs[0]},
{GetOutputVarNode("Out", node)}, {});
auto gather = CreateBaseOp(graph, node, "popart_gather", {w_node, ids},
{GetOutputVarNode("Out", node)}, {});
return gather;
}

Expand Down Expand Up @@ -370,6 +375,56 @@ Node *expand_handler(Graph *graph, Node *node) {
return new_node;
}

Node *assign_handler(Graph *graph, Node *node) {
return CreateBaseOp(graph, node, "popart_identity",
{GetInputVarNode("X", node)},
{GetOutputVarNode("Out", node)}, {});
}

Node *fill_any_like_handler(Graph *graph, Node *node) {
auto *op = node->Op();
auto value = BOOST_GET_CONST(float, op->GetAttr("value"));
auto x_shape = GetInputVarNode("X", node)->Var()->GetShape();
auto dtype = BOOST_GET_CONST(int, op->GetAttr("dtype"));
auto x_dtype = static_cast<proto::VarType::Type>(dtype);
size_t size = 1;
for (auto &dim : x_shape) {
size *= dim;
}

Attribute out_value;
switch (x_dtype) {
case proto::VarType::FP32:
out_value = std::vector<float>(size, value);
break;
case proto::VarType::FP64:
out_value = std::vector<double>(size, value);
break;
case proto::VarType::INT32:
out_value = std::vector<int>(size, value);
break;
case proto::VarType::INT64:
out_value = std::vector<int64_t>(size, value);
break;
case proto::VarType::BOOL:
out_value = std::vector<int64_t>(size, value);
break;
default:
PADDLE_THROW(
platform::errors::Unimplemented("fill_any_like dtype: %d", x_dtype));
}
return CreateConst(graph, node, node->inputs, node->outputs,
AttributeMap{
{"value", out_value},
{"dims", x_shape},
{"dtype", VarType2OnnxDtype(dtype)},
});
}

Node *lookup_table_v2_handler(Graph *graph, Node *node) {
return lookup_table_handler(graph, node);
}

REGISTER_HANDLER(fill_constant, fill_constant_handler);
REGISTER_HANDLER(gaussian_random, gaussian_random_handler);
REGISTER_HANDLER(uniform_random, uniform_random_handler);
Expand All @@ -385,6 +440,9 @@ REGISTER_HANDLER(stack, stack_handler);
REGISTER_HANDLER(shape, shape_handler);
REGISTER_HANDLER(slice, slice_handler);
REGISTER_HANDLER(expand, expand_handler);
REGISTER_HANDLER(assign, assign_handler);
REGISTER_HANDLER(fill_any_like, fill_any_like_handler);
REGISTER_HANDLER(lookup_table_v2, lookup_table_v2_handler);

} // namespace
} // namespace ipu
Expand Down
72 changes: 39 additions & 33 deletions paddle/fluid/framework/ir/ipu/infer_shape_pass.cc
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,8 @@

#include "paddle/fluid/framework/ir/ipu/infer_shape_pass.h"

#include "paddle/fluid/framework/ipu/ipu_backend.h"

#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/ipu/ipu_backend.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/op_registry.h"
Expand All @@ -32,9 +31,9 @@ void InferShapePass::ApplyImpl(ir::Graph* graph) const {
VLOG(10) << DebugString(graph);

// Make batch_size fixed
bool need_infer_shape = false;
std::shared_ptr<ipu::IpuBackend> ipu_backend = ipu::IpuBackend::GetInstance();
auto batch_size = ipu_backend->GetIpuStrategy()->batch_size;

auto feed_list = Get<std::vector<std::string>>("feed_list");
for (auto node : graph->Nodes()) {
if (!node->IsVar()) {
Expand All @@ -47,6 +46,7 @@ void InferShapePass::ApplyImpl(ir::Graph* graph) const {
if (input_shape[0] <= -1) {
input_shape[0] = batch_size;
node->Var()->SetShape(input_shape);
need_infer_shape = true;
}
// int64->int32
if (node->Var()->GetDataType() == proto::VarType::INT64) {
Expand All @@ -56,44 +56,50 @@ void InferShapePass::ApplyImpl(ir::Graph* graph) const {
}

// temp scope for shape inference
std::shared_ptr<paddle::framework::Scope> scope(
new paddle::framework::Scope());
for (auto node : graph->Nodes()) {
if (!node->IsVar()) {
continue;
}
auto var_desc = node->Var();
auto* ptr = scope->Var(var_desc->Name());
paddle::framework::InitializeVariable(ptr, var_desc->GetType());
if (need_infer_shape) {
std::shared_ptr<paddle::framework::Scope> scope(
new paddle::framework::Scope());
for (auto node : graph->Nodes()) {
if (!node->IsVar()) {
continue;
}
auto var_desc = node->Var();
auto* ptr = scope->Var(var_desc->Name());
paddle::framework::InitializeVariable(ptr, var_desc->GetType());

auto tensor = ptr->GetMutable<paddle::framework::LoDTensor>();
tensor->Resize(paddle::framework::make_ddim(var_desc->GetShape()));
}
auto tensor = ptr->GetMutable<paddle::framework::LoDTensor>();
tensor->Resize(paddle::framework::make_ddim(var_desc->GetShape()));
}

// infer shape
auto nodes = ir::TopologySortOperations(*graph);
for (auto node : nodes) {
auto op_desc = node->Op();
auto op = paddle::framework::OpRegistry::CreateOp(*op_desc);
paddle::framework::RuntimeContext ctx(op->Inputs(), op->Outputs(), *scope);
op->RuntimeInferShape(*scope, paddle::platform::CPUPlace(), ctx);
// infer shape
auto nodes = ir::TopologySortOperations(*graph);
for (auto node : nodes) {
VLOG(10) << "InferShapePass: Infer shape for Op (" << node->Name() << ")";
auto op_desc = node->Op();
auto op = paddle::framework::OpRegistry::CreateOp(*op_desc);
paddle::framework::RuntimeContext ctx(op->Inputs(), op->Outputs(),
*scope);
op->RuntimeInferShape(*scope, paddle::platform::CPUPlace(), ctx);

for (auto it = ctx.outputs.begin(); it != ctx.outputs.end(); it++) {
for (int i = 0; i < it->second.size(); i++) {
auto output_name = op_desc->Output(it->first)[i];
auto dim =
it->second[i]->GetMutable<paddle::framework::LoDTensor>()->dims();
auto new_shape = paddle::framework::vectorize(dim);
for (auto output_node : node->outputs) {
if (output_node->Name() == output_name) {
output_node->Var()->SetShape(new_shape);
for (auto it = ctx.outputs.begin(); it != ctx.outputs.end(); it++) {
for (int i = 0; i < it->second.size(); i++) {
auto output_name = op_desc->Output(it->first)[i];
auto dim =
it->second[i]->GetMutable<paddle::framework::LoDTensor>()->dims();
auto new_shape = paddle::framework::vectorize(dim);
for (auto output_node : node->outputs) {
if (output_node->Name() == output_name) {
output_node->Var()->SetShape(new_shape);
}
}
}
}
VLOG(10) << "InferShapePass: Infer shape for Op (" << node->Name()
<< ") finished";
}
// release the temp scope
scope.reset();
}
// release the temp scope
scope.reset();

VLOG(10) << "Post Graph: ";
VLOG(10) << DebugString(graph);
Expand Down
15 changes: 8 additions & 7 deletions paddle/fluid/framework/ir/ipu/optimizer_extract_pass.cc
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,14 @@ void IpuOptimizerExtractPass::ApplyImpl(ir::Graph* graph) const {
}
}

auto lr_var_name = node->Op()->Input("LearningRate");
PADDLE_ENFORCE_EQ(lr_var_name.size(), 1u,
platform::errors::InvalidArgument(
"In op(%s), find input(LearningRate) failed.",
node->Op()->Type()));

ipu_backend->GetExecutor().SetLRVarName(lr_var_name[0]);
if (node->Op()->HasInput("LearningRate")) {
auto lr_var_name = node->Op()->Input("LearningRate");
PADDLE_ENFORCE_EQ(lr_var_name.size(), 1u,
platform::errors::InvalidArgument(
"In op(%s), find input(LearningRate) failed.",
node->Op()->Type()));
ipu_backend->GetExecutor().SetLRVarName(lr_var_name[0]);
}
}

if ((op_role == static_cast<int>(framework::OpRole::kLoss))) {
Expand Down
37 changes: 37 additions & 0 deletions python/paddle/fluid/framework.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,13 @@
'cuda_places',
'cpu_places',
'xpu_places',
'ipu_places',
'cuda_pinned_places',
'in_dygraph_mode',
'is_compiled_with_cuda',
'is_compiled_with_rocm',
'is_compiled_with_xpu',
'is_compiled_with_ipu',
'Variable',
'require_version',
'device_guard',
Expand Down Expand Up @@ -411,6 +413,21 @@ def is_compiled_with_xpu():
return core.is_compiled_with_xpu()


def is_compiled_with_ipu():
"""
Whether this whl package can be used to run the model on IPU.
Returns (bool): support ipu or not.
Examples:
.. code-block:: python
import paddle.fluid as fluid
support_ipu = fluid.is_compiled_with_ipu()
"""
return core.is_compiled_with_ipu()


def is_compiled_with_cuda():
"""
Whether this whl package can be used to run the model on GPU.
Expand Down Expand Up @@ -527,6 +544,26 @@ def xpu_places(device_ids=None):
return [core.XPUPlace(dev_id) for dev_id in device_ids]


def ipu_places():
"""
This function creates a list of :code:`paddle.IPUPlace` objects, and returns the created list.
Returns:
list of paddle.IPUPlace: Created IPU place list.
Examples:
.. code-block:: python
import paddle
import paddle.static as static
paddle.enable_static()
ipu_places = static.ipu_places()
"""
assert core.is_compiled_with_ipu(), \
"Not compiled with IPU"
return [core.IPUPlace()]


def cpu_places(device_count=None):
"""
This function creates a list of :code:`paddle.CPUPlace` objects, and returns the created list.
Expand Down
Loading

0 comments on commit fcd219e

Please sign in to comment.