support ernie model (PaddlePaddle#96)
* update ernie model

* use inplace=False

* matmul workaround

* fix matmul

* rollback matmul

* change CreateBaseOp api

* workaround forward_graph and update cross_entropy

* fix SetIpuIndexStage

* use virtualGraph/pipelineStage

Co-authored-by: XBWGC <xiaobingw@graphcore.ai>
Co-authored-by: yaozhixin <522190855@qq.com>
3 people authored Aug 27, 2021
1 parent 1a4fd5d commit 325ffa5
Showing 16 changed files with 442 additions and 286 deletions.
32 changes: 22 additions & 10 deletions paddle/fluid/framework/ipu/compiler.cc
@@ -43,33 +43,43 @@ void Compiler::InsertTensors(std::vector<std::string> output_names,
       std::pair<std::string, std::string>(output_names[0], tensor_id));
 }
 
-void Compiler::SetIpuIndexStage(const std::vector<std::string> &tensor_ids,
-                                const OpDesc *op_desc) {
+void Compiler::SetIpuIndexStage(const std::vector<std::string>& tensor_ids,
+                                const OpDesc* op_desc) {
+  // TODO(xiaobingw): replace ipu_index with macro or constexpr
+  VLOG(10) << "enter Compiler::SetIpuIndexStage";
+  auto tensor_ids_set =
+      std::set<std::string>(tensor_ids.begin(), tensor_ids.end());
   if (op_desc->HasAttr("ipu_index")) {
     auto ipu_index = BOOST_GET_CONST(int, op_desc->GetAttr("ipu_index"));
-    for (const auto &tensor_id : tensor_ids) {
-      builder_->virtualGraph(tensor_id, ipu_index);
-    }
+    builder_->virtualGraph(tensor_ids_set, ipu_index);
+    VLOG(10) << "set ipu_index= " << ipu_index
+             << " for op: " << op_desc->Type();
     if (op_desc->HasAttr("ipu_stage")) {
       auto ipu_stage = BOOST_GET_CONST(int, op_desc->GetAttr("ipu_stage"));
-      for (const auto &tensor_id : tensor_ids) {
-        builder_->pipelineStage(tensor_id, ipu_stage);
-      }
+      builder_->pipelineStage(tensor_ids_set, ipu_stage);
+      VLOG(10) << "set ipu_stage= " << ipu_stage
+               << " for op: " << op_desc->Type();
     }
   }
+  VLOG(10) << "leave Compiler::SetIpuIndexStage";
 }
 
-void Compiler::SetIpuIndexStage(const std::string &tensor_id,
-                                const OpDesc *op_desc) {
+void Compiler::SetIpuIndexStage(const std::string& tensor_id,
+                                const OpDesc* op_desc) {
+  VLOG(10) << "enter Compiler::SetIpuIndexStage";
   if (op_desc->HasAttr("ipu_index")) {
     auto ipu_index = BOOST_GET_CONST(int, op_desc->GetAttr("ipu_index"));
     builder_->virtualGraph(tensor_id, ipu_index);
+    VLOG(10) << "set ipu_index= " << ipu_index
+             << " for op: " << op_desc->Type();
     if (op_desc->HasAttr("ipu_stage")) {
       auto ipu_stage = BOOST_GET_CONST(int, op_desc->GetAttr("ipu_stage"));
       builder_->pipelineStage(tensor_id, ipu_stage);
+      VLOG(10) << "set ipu_stage= " << ipu_stage
+               << " for op: " << op_desc->Type();
     }
   }
+  VLOG(10) << "leave Compiler::SetIpuIndexStage";
 }
 
 template <typename T>
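The loop-per-tensor calls above collapse into single set-based calls. A standalone sketch of the effect, with FakeBuilder standing in for popart::Builder (whose set overloads the new code calls); the de-duplication shown here is an inferred benefit, not something the commit states:

#include <cstdint>
#include <iostream>
#include <set>
#include <string>
#include <vector>

// Stand-in for popart::Builder; the real class is what receives the
// std::set of tensor ids in the diff above.
struct FakeBuilder {
  void virtualGraph(const std::set<std::string>& ids, int64_t ipu) {
    for (const auto& id : ids) std::cout << id << " -> IPU " << ipu << "\n";
  }
};

int main() {
  // An op may repeat an output tensor id; building the set first means the
  // placement attribute is applied once per unique tensor.
  std::vector<std::string> tensor_ids{"matmul:0", "matmul:0", "matmul:1"};
  std::set<std::string> tensor_ids_set(tensor_ids.begin(), tensor_ids.end());
  FakeBuilder builder;
  builder.virtualGraph(tensor_ids_set, /*ipu_index=*/1);  // prints two lines
  return 0;
}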
@@ -219,6 +229,7 @@ void Compiler::RegisterOpFunc() {
   name_function_.emplace("popart_batchnormalization", BatchNormHandler);
   name_function_.emplace("popart_constant", Constant);
   name_function_.emplace("popart_nllloss", NllLoss);
+  name_function_.emplace("popart_groupnormalization", Groupnormalization);
 }
 
 void Compiler::LowerBody(const ir::Graph* graph) {
@@ -241,6 +252,7 @@ void Compiler::LowerBody(const ir::Graph* graph) {
     auto func = name_function_[op->Type()];
     func(node->Op());
   }
+  VLOG(10) << "leave Compiler::LowerBody";
 }
 
 }  // namespace ipu
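RegisterOpFunc and LowerBody pair a string-keyed handler map with a per-node lookup; the groupnormalization line above is one new registration. The pattern, reduced to a minimal sketch (handler bodies are placeholders; the real handlers take an OpDesc*):

#include <functional>
#include <iostream>
#include <map>
#include <string>

int main() {
  // RegisterOpFunc: one handler per popart op type.
  std::map<std::string, std::function<void()>> name_function;
  name_function.emplace("popart_nllloss",
                        [] { std::cout << "lower nllloss\n"; });
  name_function.emplace("popart_groupnormalization",
                        [] { std::cout << "lower groupnorm\n"; });

  // LowerBody: walk the ops and dispatch on type.
  const char* op_types[] = {"popart_nllloss", "popart_groupnormalization"};
  for (const char* op_type : op_types) {
    auto func = name_function[op_type];
    func();
  }
  return 0;
}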
@@ -24,7 +24,7 @@ namespace {
 ir::Node *activation_op_handler(ir::Graph *graph, ir::Node *node,
                                 const std::string &type) {
   auto new_node =
-      CreateBaseOp(graph, type, {GetInputNode("X", node)}, node->outputs);
+      CreateBaseOp(graph, node, type, {GetInputNode("X", node)}, node->outputs);
   return new_node;
 }
@@ -54,10 +54,8 @@ ir::Node *gelu_handler(ir::Graph *graph, ir::Node *node) {
 
 ir::Node *log_softmax_handler(ir::Graph *graph, ir::Node *node) {
   auto axis_ = BOOST_GET_CONST(int, node->Op()->GetAttr("axis"));
-  return CreateBaseOp(graph, "popart_logsoftmax", node->inputs, node->outputs,
-                      {
-                          {"axis", int64_t{axis_}},
-                      });
+  return CreateBaseOp(graph, node, "popart_logsoftmax", node->inputs,
+                      node->outputs, {{"axis", int64_t{axis_}}});
 }
 
 REGISTER_HANDLER(relu, relu_handler);
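Every handler in this commit now threads the source node into CreateBaseOp. A plausible reading, consistent with the CopyOpAttr and SetIpuIndexStage changes elsewhere in the diff: the source op's placement attributes must survive onto the popart ops that replace it. A toy model of that propagation (ToyOp and this CreateBaseOp are illustrative, not the paddle API):

#include <iostream>
#include <map>
#include <string>

// Toy stand-in for the paddle ir types; only attribute plumbing is modeled.
struct ToyOp {
  std::string type;
  std::map<std::string, int> attrs;
};

// Mirrors the shape of the new API: the extra source-op argument lets the
// helper copy per-op placement attributes onto the op it creates.
ToyOp CreateBaseOp(const ToyOp& src, const std::string& type) {
  ToyOp out{type, {}};
  const char* keys[] = {"ipu_index", "ipu_stage"};
  for (const char* key : keys) {
    auto it = src.attrs.find(key);
    if (it != src.attrs.end()) out.attrs[key] = it->second;
  }
  return out;
}

int main() {
  ToyOp relu{"relu", {{"ipu_index", 1}, {"ipu_stage", 0}}};
  ToyOp lowered = CreateBaseOp(relu, "popart_relu");
  std::cout << lowered.type << " keeps ipu_index="
            << lowered.attrs["ipu_index"] << "\n";
  return 0;
}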
@@ -77,6 +77,8 @@ void CopyOpAttr(const std::string &attr_name, OpDesc *op, OpDesc *new_op,
     return;
   }
   if (op->HasAttr(attr_name)) {
+    VLOG(10) << "Copying attr: " << attr_name << " from " << op->Type()
+             << " to " << new_op->Type();
     new_op->SetAttr(attr_name, op->GetAttr(attr_name));
     new_op->Flush();
   }
@@ -32,7 +32,7 @@ ir::Node *elementwise_op_handler(ir::Graph *graph, ir::Node *node,
   auto axis = BOOST_GET_CONST(int, op->GetAttr("axis"));
   if (axis == -1 || axis == x_rank - 1 || x_rank == y_rank) {
     auto new_node = CreateBaseOp(
-        graph, type, {GetInputNode("X", node), GetInputNode("Y", node)},
+        graph, node, type, {GetInputNode("X", node), GetInputNode("Y", node)},
         node->outputs);
     return new_node;
   } else {
@@ -46,15 +46,15 @@ ir::Node *elementwise_op_handler(ir::Graph *graph, ir::Node *node,
         {"dtype", ONNXDataType::INT64},
     };
     // constant
-    auto new_node_const = CreateConst(graph, {}, {}, attrs);
+    auto new_node_const = CreateConst(graph, node, {}, {}, attrs);
     // reshape
     auto new_node_reshape =
-        CreateBaseOp(graph, "popart_reshape",
+        CreateBaseOp(graph, node, "popart_reshape",
                      {GetInputNode("Y", node), new_node_const->outputs[0]}, {});
     // elementwise_op
     auto new_node = CreateBaseOp(
-        graph, type, {GetInputNode("X", node), new_node_reshape->outputs[0]},
-        node->outputs);
+        graph, node, type,
+        {GetInputNode("X", node), new_node_reshape->outputs[0]}, node->outputs);
     return new_node;
   }
 }
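The else-branch above materializes Paddle's axis-style broadcast as an explicit reshape of Y, since popart's binary ops broadcast from trailing dimensions. The reshape_shape_ computation itself sits in the collapsed part of the hunk; a sketch of the standard semantics it presumably implements:

#include <cstdint>
#include <iostream>
#include <vector>

// Paddle's axis-style broadcast as an explicit shape: Y's dims are placed
// at [axis, axis + y_rank) and the rest padded with 1s, so a trailing-
// dimension (numpy-style) broadcast produces the same result.
std::vector<int64_t> YReshapeForAxisBroadcast(
    const std::vector<int64_t>& y_shape, int x_rank, int axis) {
  std::vector<int64_t> shape(x_rank, 1);
  for (size_t i = 0; i < y_shape.size(); ++i) shape[axis + i] = y_shape[i];
  return shape;
}

int main() {
  // X: [2, 3, 4, 5], Y: [3, 4], axis = 1  ->  Y reshaped to [1, 3, 4, 1]
  for (int64_t d : YReshapeForAxisBroadcast({3, 4}, 4, 1))
    std::cout << d << " ";
  std::cout << "\n";
  return 0;
}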
@@ -23,8 +23,8 @@ namespace {
 
 ir::Node *equal_handler(ir::Graph *graph, ir::Node *node) {
   auto new_node = CreateBaseOp(
-      graph, "popart_equal", {GetInputNode("X", node), GetInputNode("Y", node)},
-      node->outputs);
+      graph, node, "popart_equal",
+      {GetInputNode("X", node), GetInputNode("Y", node)}, node->outputs);
   return new_node;
 }

161 changes: 119 additions & 42 deletions paddle/fluid/framework/ipu/popart_canonicalization/math_ops.cc
@@ -33,13 +33,13 @@ ir::Node *reduce_mean_handler(ir::Graph *graph, ir::Node *node) {
   auto keepdims_ = BOOST_GET_CONST(bool, op->GetAttr("keep_dim"));
   auto keepdims = int64_t{keepdims_};
   attrs.emplace("keepdims", keepdims);
-  return CreateBaseOp(graph, "popart_reducemean", node->inputs, node->outputs,
-                      attrs);
+  return CreateBaseOp(graph, node, "popart_reducemean", node->inputs,
+                      node->outputs, attrs);
 }
 
 ir::Node *mean_handler(ir::Graph *graph, ir::Node *node) {
-  return CreateBaseOp(graph, "popart_reducemean", {GetInputNode("X", node)},
-                      {GetOutputNode("Out", node)},
+  return CreateBaseOp(graph, node, "popart_reducemean",
+                      {GetInputNode("X", node)}, {GetOutputNode("Out", node)},
                       {
                           {"keepdims", int64_t{0}},
                       });
@@ -51,8 +51,8 @@ ir::Node *pow_handler(ir::Graph *graph, ir::Node *node) {
   auto value_ = BOOST_GET_CONST(float, op->GetAttr("factor"));
   auto attrs =
       MakeConstAttrMapFromValue<float>(value_, {1}, ONNXDataType::FLOAT);
-  auto new_node_const = CreateConst(graph, {}, {}, attrs);
-  return CreateBaseOp(graph, "popart_pow",
+  auto new_node_const = CreateConst(graph, node, {}, {}, attrs);
+  return CreateBaseOp(graph, node, "popart_pow",
                       {GetInputNode("X", node), new_node_const->outputs[0]},
                       node->outputs);
 }
@@ -73,46 +73,86 @@ ir::Node *mul_handler(ir::Graph *graph, ir::Node *node) {
     reshape_shape_.push_back(int64_t(y_shape_[right]));
   }
   auto x_flatten =
-      CreateBaseOp(graph, "popart_flatten", {GetInputNode("X", node)}, {},
+      CreateBaseOp(graph, node, "popart_flatten", {GetInputNode("X", node)}, {},
                    {{"axis", int64_t(x_num_col_dims)}});
   auto y_flatten =
-      CreateBaseOp(graph, "popart_flatten", {GetInputNode("Y", node)}, {},
+      CreateBaseOp(graph, node, "popart_flatten", {GetInputNode("Y", node)}, {},
                    {{"axis", int64_t(y_num_col_dims)}});
   auto matmul =
-      CreateBaseOp(graph, "popart_matmul",
+      CreateBaseOp(graph, node, "popart_matmul",
                    {x_flatten->outputs[0], y_flatten->outputs[0]}, {}, {});
 
   auto reshape_const = CreateConst(
-      graph, {}, {},
+      graph, node, {}, {},
       {{"value", reshape_shape_},
        {"dims", std::vector<int64_t>{int64_t(reshape_shape_.size())}},
        {"dtype", ONNXDataType::INT64}});
-  return CreateBaseOp(graph, "popart_reshape",
+  return CreateBaseOp(graph, node, "popart_reshape",
                       {matmul->outputs[0], reshape_const->outputs[0]},
                       node->outputs, {});
 }
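mul_handler (touched here only for the new signature) lowers Paddle's mul by flattening both operands to rank 2 at their num_col_dims, running a single matmul, and reshaping back to the recorded output shape. The flatten arithmetic, assuming popart_flatten follows ONNX Flatten semantics:

#include <cstdint>
#include <iostream>
#include <vector>

// Flatten with axis = num_col_dims collapses dims [0, axis) into rows and
// [axis, rank) into columns, giving the rank-2 operand mul needs.
std::vector<int64_t> FlattenAt(const std::vector<int64_t>& shape, int axis) {
  int64_t rows = 1, cols = 1;
  for (int i = 0; i < static_cast<int>(shape.size()); ++i)
    (i < axis ? rows : cols) *= shape[i];
  return {rows, cols};
}

int main() {
  // X: [2, 3, 4, 5] with x_num_col_dims = 2  ->  [6, 20]
  auto flat = FlattenAt({2, 3, 4, 5}, 2);
  std::cout << flat[0] << " x " << flat[1] << "\n";
  return 0;
}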

 ir::Node *matmul_handler(ir::Graph *graph, ir::Node *node) {
   auto *op = node->Op();
   auto transpose_x = BOOST_GET_CONST(bool, op->GetAttr("transpose_X"));
   auto transpose_y = BOOST_GET_CONST(bool, op->GetAttr("transpose_Y"));
   auto alpha = BOOST_GET_CONST(float, op->GetAttr("alpha"));
-  return CreateGemm(graph, node->inputs, node->outputs, transpose_x,
-                    transpose_y, alpha);
+  auto x_shape = GetInputNodeShape("X", node);
+  auto y_shape = GetInputNodeShape("Y", node);
+  int x_rank = x_shape.size();
+  std::vector<int64_t> perm;
+  if (x_rank == 2) {
+    return CreateGemm(graph, node,
+                      {GetInputNode("X", node), GetInputNode("Y", node)},
+                      node->outputs, transpose_x, transpose_y, alpha);
+  } else if (x_rank == 3) {
+    perm = std::vector<int64_t>{0, 2, 1};
+  } else if (x_rank == 4) {
+    perm = std::vector<int64_t>{0, 1, 3, 2};
+  } else {
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "op matmul with input rank == %d", x_rank));
+  }
+
+  Node *x_node = GetInputNode("X", node);
+  Node *y_node = GetInputNode("Y", node);
+  if (transpose_x) {
+    x_node = CreateBaseOp(graph, node, "popart_transpose",
+                          {GetInputNode("X", node)}, {}, {{"perm", perm}});
+    x_node = x_node->outputs[0];
+  }
+  if (transpose_y) {
+    y_node = CreateBaseOp(graph, node, "popart_transpose",
+                          {GetInputNode("Y", node)}, {}, {{"perm", perm}});
+    y_node = y_node->outputs[0];
+  }
+  // TODO(alleng) move 1e-8 to global or create a funtion like equal()
+  if (abs(alpha - 1.0) > 1e-8) {
+    auto o_node =
+        CreateBaseOp(graph, node, "popart_matmul", {x_node, y_node}, {});
+    auto attr = MakeConstAttrMapFromValue(alpha, {1}, ONNXDataType::FLOAT);
+    auto const_node = CreateConst(graph, node, {}, {}, attr);
+    return CreateBaseOp(graph, node, "popart_mul",
+                        {o_node->outputs[0], const_node->outputs[0]},
+                        node->outputs);
+  } else {
+    return CreateBaseOp(graph, node, "popart_matmul", {x_node, y_node},
+                        node->outputs);
+  }
 }
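The reworked matmul_handler keeps the CreateGemm path for rank-2 inputs and decomposes higher ranks: transpose_X/transpose_Y become explicit popart_transpose ops that swap only the last two axes, and alpha != 1.0 becomes a trailing scalar multiply, i.e. out = alpha * (X' @ Y'). The axis bookkeeping, extracted as a standalone sketch (not the handler itself):

#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <vector>

// The permutation used for transpose_X/transpose_Y: batch dims stay put,
// only the last two axes swap, so the perm depends on rank (rank 2 is
// routed to CreateGemm and never reaches this path).
std::vector<int64_t> LastTwoAxesPerm(int rank) {
  if (rank == 3) return {0, 2, 1};
  if (rank == 4) return {0, 1, 3, 2};
  throw std::invalid_argument("matmul workaround supports rank 3 or 4 only");
}

int main() {
  for (int64_t axis : LastTwoAxesPerm(4)) std::cout << axis << " ";
  std::cout << "\n";  // 0 1 3 2: batch dims fixed, last two swapped
  return 0;
}

Note that the handler derives perm from X's rank alone and applies it to both inputs, so it implicitly assumes X and Y have equal rank.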

 ir::Node *sum_handler(ir::Graph *graph, ir::Node *node) {
-  return CreateBaseOp(graph, "popart_sum", node->inputs, node->outputs);
+  return CreateBaseOp(graph, node, "popart_sum", node->inputs, node->outputs);
 }
 
 ir::Node *softmax_handler(ir::Graph *graph, ir::Node *node) {
   auto *op = node->Op();
   auto axis_ = BOOST_GET_CONST(int, op->GetAttr("axis"));
   auto axis = int64_t{axis_};
-  return CreateBaseOp(graph, "popart_softmax", node->inputs, node->outputs,
-                      {
-                          {"axis", axis},
-                      });
+  return CreateBaseOp(graph, node, "popart_softmax", node->inputs,
+                      node->outputs, {
+                          {"axis", axis},
+                      });
 }

 ir::Node *scale_handler(ir::Graph *graph, ir::Node *node) {
@@ -125,55 +165,92 @@ ir::Node *scale_handler(ir::Graph *graph, ir::Node *node) {
 
   // TODO(yaozhixin): support tensor as scale input
   if (abs(scale_ - 1.0) < 1e-06 && abs(bias_ - 0.0) < 1e-06) {
-    auto new_node_identity = CreateBaseOp(
-        graph, "popart_identity", {GetInputNode("X", node)}, node->outputs, {});
+    auto new_node_identity =
+        CreateBaseOp(graph, node, "popart_identity", {GetInputNode("X", node)},
+                     node->outputs, {});
     return new_node_identity;
   } else {
     auto new_node_bias =
-        CreateConst(graph, {}, {}, {{"value", std::vector<float>{bias_}},
-                                    {"dims", std::vector<int64_t>{1}},
-                                    {"dtype", ONNXDataType::FLOAT}});
+        CreateConst(graph, node, {}, {}, {{"value", std::vector<float>{bias_}},
+                                          {"dims", std::vector<int64_t>{1}},
+                                          {"dtype", ONNXDataType::FLOAT}});
     auto new_node_scale =
-        CreateConst(graph, {}, {}, {{"value", std::vector<float>{scale_}},
-                                    {"dims", std::vector<int64_t>{1}},
-                                    {"dtype", ONNXDataType::FLOAT}});
+        CreateConst(graph, node, {}, {}, {{"value", std::vector<float>{scale_}},
+                                          {"dims", std::vector<int64_t>{1}},
+                                          {"dtype", ONNXDataType::FLOAT}});
     // convert to float32
-    auto new_node_cast = CreateCast(graph, {GetInputNode("X", node)}, {},
+    auto new_node_cast = CreateCast(graph, node, {GetInputNode("X", node)}, {},
                                     static_cast<int>(proto::VarType::FP32));
 
     ir::Node *result = nullptr;
     if (bias_after_scale_) {
       auto new_node_mul = CreateBaseOp(
-          graph, "popart_mul",
+          graph, node, "popart_mul",
           {new_node_cast->outputs[0], new_node_scale->outputs[0]}, {}, {});
       result = CreateBaseOp(
-          graph, "popart_add",
+          graph, node, "popart_add",
          {new_node_mul->outputs[0], new_node_bias->outputs[0]}, {}, {});
     } else {
       auto new_node_add = CreateBaseOp(
-          graph, "popart_add",
+          graph, node, "popart_add",
          {new_node_cast->outputs[0], new_node_bias->outputs[0]}, {}, {});
       result = CreateBaseOp(
-          graph, "popart_mul",
+          graph, node, "popart_mul",
          {new_node_add->outputs[0], new_node_scale->outputs[0]}, {}, {});
     }
-    auto result_after_cast = CreateCast(graph, result->outputs, node->outputs,
-                                        static_cast<int>(data_type_));
+    auto result_after_cast =
+        CreateCast(graph, node, result->outputs, node->outputs,
+                   static_cast<int>(data_type_));
    return result_after_cast;
  }
 }
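The two branches of scale_handler differ only in operation order. A worked example of what bias_after_scale selects (values chosen arbitrarily):

#include <iostream>

// After casting X to fp32, the handler emits one of two orderings:
// bias_after_scale = true:   out = scale * x + bias   (mul, then add)
// bias_after_scale = false:  out = scale * (x + bias) (add, then mul)
int main() {
  float x = 2.0f, scale = 3.0f, bias = 1.0f;
  std::cout << scale * x + bias << "\n";    // 7
  std::cout << scale * (x + bias) << "\n";  // 9
  return 0;
}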

 ir::Node *cross_entropy2_handler(ir::Graph *graph, ir::Node *node) {
   auto *op = node->Op();
   auto ignoreIndex = BOOST_GET_CONST(int, op->GetAttr("ignore_index"));
-  auto new_cast = CreateCast(graph, {GetInputNode("Label", node)}, {},
+  auto new_cast = CreateCast(graph, node, {GetInputNode("Label", node)}, {},
                              proto::VarType::INT32);
-  return CreateBaseOp(graph, "popart_nllloss",
-                      {GetInputNode("X", node), new_cast->outputs[0]},
-                      {GetOutputNode("Y", node)},
-                      {
-                          {"ignoreIndex", ignoreIndex},
-                      });
+  auto label_shape_ = op->Block()->FindVar(op->Input("Label")[0])->GetShape();
+  if (label_shape_.size() == 1) {
+    return CreateBaseOp(graph, node, "popart_nllloss",
+                        {GetInputNode("X", node), new_cast->outputs[0]},
+                        {GetOutputNode("Y", node)},
+                        {
+                            {"ignoreIndex", ignoreIndex},
+                        });
+  } else {
+    std::vector<int64_t> new_shape_{label_shape_[0]};
+    auto const_before_loss = CreateBaseOp(
+        graph, node, "popart_constant", {}, {},
+        {{"value", new_shape_},
+         {"dims",
+          std::vector<int64_t>{static_cast<int64_t>(new_shape_.size())}},
+         {"dtype", ONNXDataType::INT64}});
+
+    auto reshape_before_loss = CreateBaseOp(
+        graph, node, "popart_reshape",
+        {new_cast->outputs[0], const_before_loss->outputs[0]}, {}, {});
+
+    auto nllloss = CreateBaseOp(
+        graph, node, "popart_nllloss",
+        {GetInputNode("X", node), reshape_before_loss->outputs[0]}, {},
+        {
+            {"ignoreIndex", ignoreIndex},
+        });
+
+    auto const_after_loss = CreateBaseOp(
+        graph, node, "popart_constant", {}, {},
+        {{"value", label_shape_},
+         {"dims",
+          std::vector<int64_t>{static_cast<int64_t>(label_shape_.size())}},
+         {"dtype", ONNXDataType::INT64}});
+
+    auto reshape_after_loss =
+        CreateBaseOp(graph, node, "popart_reshape",
+                     {nllloss->outputs[0], const_after_loss->outputs[0]},
+                     {GetOutputNode("Y", node)}, {});
+    return reshape_after_loss;
+  }
 }

 REGISTER_HANDLER(reduce_mean, reduce_mean_handler);
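The new cross_entropy2 branch covers Paddle labels that carry a trailing dimension (e.g. [N, 1]) while popart_nllloss expects a rank-1 label: the label is reshaped down before the loss and the per-sample loss reshaped back to the label's shape afterwards. A shape-only sketch of that round trip (a reconstruction of intent; names are illustrative):

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int64_t> label_shape{8, 1};           // label as Paddle sees it
  std::vector<int64_t> nllloss_in{label_shape[0]};  // [8] fed to popart_nllloss
  std::vector<int64_t> loss_out = label_shape;      // loss reshaped back
  std::cout << "label [8,1] -> [" << nllloss_in[0] << "] -> loss ["
            << loss_out[0] << "," << loss_out[1] << "]\n";
  return 0;
}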