Skip to content

Commit

Permalink
[NewIR]New ir using kernel registrer type (PaddlePaddle#56789)
Browse files Browse the repository at this point in the history
* update

* fix batch norm grad args def

* fix bug

* fix combine slice bug

* fix slice bug

* update builtin split

* disable using kernel resigter dtype

* polish code

* disable some test
  • Loading branch information
phlrain authored and BeingGod committed Sep 9, 2023
1 parent 4e0f41e commit 2b809df
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 122 deletions.
191 changes: 71 additions & 120 deletions paddle/fluid/ir/transforms/pd_op_to_kernel_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,50 @@ ir::OpResult AddPlaceTransferOp(ir::OpResult in,
}
}

ir::Type BuildOutputType(ir::Type type,
const phi::Place& place,
phi::DataType data_type,
ir::IrContext* ctx) {
if (type.isa<dialect::DenseTensorType>()) {
auto dense_tensor_type = type.dyn_cast<dialect::DenseTensorType>();
auto out_dtype = dense_tensor_type.dtype();

// TODO(phlrain): open this after fix pr(55509) confict
// if (data_type != phi::DataType::UNDEFINED) {
// out_dtype = TransToIrDataType(data_type, ctx);
// }

return dialect::AllocatedDenseTensorType::get(
ctx,
place,
out_dtype,
dense_tensor_type.dims(),
dense_tensor_type.data_layout(),
dense_tensor_type.lod(),
dense_tensor_type.offset());

} else if (type.isa<dialect::SelectedRowsType>()) {
auto selected_rows_type = type.dyn_cast<dialect::SelectedRowsType>();
auto out_dtype = selected_rows_type.dtype();

// TODO(phlrain): open this after fix pr(55509) confict
// if (data_type != phi::DataType::UNDEFINED) {
// out_dtype = TransToIrDataType(data_type, ctx);
// }
return dialect::AllocatedSelectedRowsType::get(
ctx,
place,
out_dtype,
selected_rows_type.dims(),
selected_rows_type.data_layout(),
selected_rows_type.lod(),
selected_rows_type.offset());
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"BuildOutputType only support DenseTensorType and SelectedRowsType"));
}
}

phi::DataType GetKernelDataTypeByYamlInfo(
const ir::Operation* op,
const std::unordered_map<ir::Value, ir::OpResult>& map_value_pair,
Expand Down Expand Up @@ -525,6 +569,7 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog,
std::vector<phi::Place> out_places;
// Copy op inputs
std::vector<ir::OpResult> vec_inputs;
std::vector<ir::Type> vec_inner_types;
if (op_item->num_operands() > 0) {
for (size_t i = 0; i < op_item->num_operands(); ++i) {
auto cur_in = op_item->operand_source(i);
Expand All @@ -540,6 +585,7 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog,
op_item->name()));
auto new_in = map_value_pair.at(cur_in);
vec_inputs.push_back(new_in);
vec_inner_types.push_back(new_in.type());
if (new_in.type().isa<paddle::dialect::AllocatedDenseTensorType>()) {
out_places.push_back(
new_in.type()
Expand All @@ -553,49 +599,9 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog,
}
// Copy op output type
std::vector<ir::Type> op_output_types;
if (op_item->num_results() > 0) {
for (size_t i = 0; i < op_item->num_results(); ++i) {
auto result_type = op_item->result(i).type();
if (!result_type) {
op_output_types.push_back(result_type);
} else if (result_type.isa<ir::VectorType>()) {
std::vector<ir::Type> vec_inner_types;
auto base_types = result_type.dyn_cast<ir::VectorType>().data();
for (size_t idx = 0; idx < base_types.size(); idx++) {
auto& base_type = base_types[idx];
if (base_type) {
if (base_type.isa<dialect::DenseTensorType>()) {
auto allocated_dense_tensor_dtype =
paddle::dialect::AllocatedDenseTensorType::get(
ctx,
out_places[idx],
base_type.dyn_cast<dialect::DenseTensorType>());
vec_inner_types.push_back(allocated_dense_tensor_dtype);
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"only support dense tensor in vector type for now"));
}
} else {
// NOTE(phlrain), kernel not support a nullptr in output
ir::Type fp32_dtype = ir::Float32Type::get(ctx);
phi::DDim dims = {};
phi::DataLayout data_layout = phi::DataLayout::NCHW;
phi::LoD lod = {{}};
size_t offset = 0;
auto dense_tensor_dtype = paddle::dialect::DenseTensorType::get(
ctx, fp32_dtype, dims, data_layout, lod, offset);
vec_inner_types.push_back(dense_tensor_dtype);
}
}
ir::Type t1 = ir::VectorType::get(ctx, vec_inner_types);
op_output_types.push_back(t1);
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"builtin.combine Result type only support "
"VectorType<DenseTensorType>"));
}
}
}
ir::Type t1 = ir::VectorType::get(ctx, vec_inner_types);
op_output_types.push_back(t1);

// Get op info
ir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_item->name());
// Generate new op
Expand All @@ -614,9 +620,8 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog,
}

if (op_item->name() == "builtin.slice") {
phi::Place out_place = place;
// Copy op inputs
std::vector<ir::OpResult> vec_inputs;
std::vector<ir::Type> op_output_types;
if (op_item->num_operands() > 0) {
for (size_t i = 0; i < op_item->num_operands(); ++i) {
auto cur_in = op_item->operand_source(i);
Expand All @@ -635,39 +640,18 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog,

if (new_in.type().isa<ir::VectorType>()) {
auto vec_types = new_in.type().dyn_cast<ir::VectorType>().data();
out_place =
vec_types[op_item->attributes()
.at("index")
.dyn_cast<ir::Int32Attribute>()
.data()]
.dyn_cast<paddle::dialect::AllocatedDenseTensorType>()
.place();
auto index = op_item->attributes()
.at("index")
.dyn_cast<ir::Int32Attribute>()
.data();
op_output_types.push_back(vec_types[index]);
} else {
PADDLE_THROW(
phi::errors::Unimplemented("only support vector type for now"));
}
}
}
// Copy op output type
std::vector<ir::Type> op_output_types;
if (op_item->num_results() > 0) {
for (size_t i = 0; i < op_item->num_results(); ++i) {
auto result_type = op_item->result(i).type();
if (!result_type) {
op_output_types.push_back(result_type);
} else if (result_type.isa<dialect::DenseTensorType>()) {
auto allocated_dense_tensor_dtype =
paddle::dialect::AllocatedDenseTensorType::get(
ctx,
out_place,
result_type.dyn_cast<dialect::DenseTensorType>());
op_output_types.push_back(allocated_dense_tensor_dtype);
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"builtin.slice Result type only support DenseTensorType"));
}
}
}

// Get op info
ir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_item->name());
// Generate new op
Expand All @@ -689,6 +673,7 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog,
std::vector<phi::Place> out_places(op_item->num_results());
// Copy op inputs
std::vector<ir::OpResult> vec_inputs;
std::vector<ir::Type> op_output_types;
if (op_item->num_operands() > 0) {
for (size_t i = 0; i < op_item->num_operands(); ++i) {
auto cur_in = op_item->operand_source(i);
Expand All @@ -708,37 +693,15 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog,
if (new_in.type().isa<ir::VectorType>()) {
auto vec_types = new_in.type().dyn_cast<ir::VectorType>().data();
for (uint64_t idx = 0; idx < vec_types.size(); idx++) {
out_places[idx] =
vec_types[idx]
.dyn_cast<paddle::dialect::AllocatedDenseTensorType>()
.place();
op_output_types.push_back(vec_types[idx]);
}
} else {
PADDLE_THROW(
phi::errors::Unimplemented("only support vector type for now"));
}
}
}
// Copy op output type
std::vector<ir::Type> op_output_types;
if (op_item->num_results() > 0) {
for (size_t i = 0; i < op_item->num_results(); ++i) {
auto result_type = op_item->result(i).type();
if (!result_type) {
op_output_types.push_back(result_type);
} else if (result_type.isa<dialect::DenseTensorType>()) {
auto allocated_dense_tensor_dtype =
paddle::dialect::AllocatedDenseTensorType::get(
ctx,
out_places[i],
result_type.dyn_cast<dialect::DenseTensorType>());
op_output_types.push_back(allocated_dense_tensor_dtype);
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"builtin.split Result type only support DenseTensorType"));
}
}
}

// Get op info
ir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_item->name());
// Generate new op
Expand Down Expand Up @@ -805,36 +768,30 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog,
}

for (size_t i = 0; i < op_item->num_results(); ++i) {
phi::Place out_place;
phi::Place out_place = phi::TransToPhiPlace(kernel_key.backend());

phi::DataType out_phi_dtype = phi::DataType::UNDEFINED;
if ((!UnchangeOutputOps.count(op_item->name())) &&
(!IsLegacyOp(op_item->name())) && phi_kernel.IsValid()) {
out_place = phi::TransToPhiPlace(output_defs[i].backend);
} else {
out_place = phi::TransToPhiPlace(kernel_key.backend());
out_phi_dtype = output_defs[i].dtype;
}

auto result_type = op_item->result(i).type();
if (!result_type) {
op_output_types.push_back(result_type);
} else if (result_type.isa<dialect::DenseTensorType>()) {
auto allocated_dense_tensor_dtype =
paddle::dialect::AllocatedDenseTensorType::get(
ctx,
out_place,
result_type.dyn_cast<dialect::DenseTensorType>());
op_output_types.push_back(allocated_dense_tensor_dtype);
} else if (result_type.isa<dialect::DenseTensorType>() ||
result_type.isa<dialect::SelectedRowsType>()) {
op_output_types.push_back(
BuildOutputType(result_type, out_place, out_phi_dtype, ctx));
} else if (result_type.isa<ir::VectorType>()) {
std::vector<ir::Type> vec_inner_types;
auto base_types = result_type.dyn_cast<ir::VectorType>().data();
for (auto& base_type : base_types) {
if (base_type) {
if (base_type.isa<dialect::DenseTensorType>()) {
auto allocated_dense_tensor_dtype =
paddle::dialect::AllocatedDenseTensorType::get(
ctx,
out_place,
base_type.dyn_cast<dialect::DenseTensorType>());
vec_inner_types.push_back(allocated_dense_tensor_dtype);
vec_inner_types.push_back(
BuildOutputType(base_type, out_place, out_phi_dtype, ctx));
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"only support dense tensor in vector type for now"));
Expand All @@ -857,16 +814,10 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog,

ir::Type t1 = ir::VectorType::get(ctx, vec_inner_types);
op_output_types.push_back(t1);
} else if (result_type.isa<dialect::SelectedRowsType>()) {
auto allocated_selected_rows_dtype =
paddle::dialect::AllocatedSelectedRowsType::get(
ctx,
out_place,
result_type.dyn_cast<dialect::SelectedRowsType>());
op_output_types.emplace_back(allocated_selected_rows_dtype);
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"Result type only support DenseTensorType and VectorType"));
"Result type only support DenseTensorType, SelectedRowType and "
"VectorType"));
}
}
}
Expand Down
2 changes: 0 additions & 2 deletions paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1387,7 +1387,6 @@ PD_REGISTER_KERNEL(batch_norm_grad,
phi::dtype::float16) {
if (kernel_key.dtype() == phi::DataType::FLOAT16 ||
kernel_key.dtype() == phi::DataType::BFLOAT16) {
kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32); // x_grad
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32); // scale_grad
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32); // bias_grad
}
Expand All @@ -1405,7 +1404,6 @@ PD_REGISTER_KERNEL(batch_norm_grad,
double,
phi::dtype::float16) {
if (kernel_key.dtype() == phi::DataType::FLOAT16) {
kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32); // x_grad
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32); // scale_grad
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32); // bias_grad
}
Expand Down
20 changes: 20 additions & 0 deletions test/ir/new_ir/test_standalone_new_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,26 @@ def func(x, y):
np.testing.assert_array_equal(z.numpy(), gold_res)


# TODO(phlrain): open this after fix pr(55509) confict
# class TestNewIrLogicalDygraph(unittest.TestCase):
# def test_with_new_ir(self):
# paddle.disable_static()

# @paddle.jit.to_static
# def func(x, y, z):
# a = paddle.logical_and(x, y)
# return z + a.cast("float32")

# x = paddle.ones([2, 2], dtype='float32')
# y = paddle.ones([2, 2], dtype='float32')
# z = paddle.ones([2, 2], dtype='float32')

# z = func(x, y, z)

# gold_res = np.ones([2, 2], dtype="float32") * 2
# np.testing.assert_array_equal(z.numpy(), gold_res)


if __name__ == "__main__":
paddle.enable_static()
unittest.main()

0 comments on commit 2b809df

Please sign in to comment.