[Prim][PIR] Recover pir bn (#60689)
* reopen bn prim pir

* fix atol

* decomp support batch_norm_

* fix test case

* fix bug

* fix code
cyber-pioneer authored Jan 12, 2024
1 parent fcb2137 commit 125a671
Showing 8 changed files with 122 additions and 14 deletions.
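For orientation, here is a rough end-to-end sketch of the path this commit re-opens: with forward prim enabled, a BatchNorm layer traced under to_static should run on decomposed primitive ops instead of the fused batch_norm kernel. The flag names (FLAGS_enable_pir_api, _set_prim_forward_enabled) are assumptions drawn from existing prim tests, not from this diff.

# Hedged usage sketch (not part of this diff): exercise the batch_norm
# decomposition that this commit re-enables. Flag names are assumptions
# and may differ between Paddle versions.
import os

os.environ.setdefault("FLAGS_enable_pir_api", "1")  # route execution through PIR

import numpy as np
import paddle
from paddle.base import core

x = paddle.randn([4, 8, 16, 16], dtype="float32")
bn = paddle.nn.BatchNorm2D(8)
bn.eval()

ref = bn(x).numpy()  # fused batch_norm kernel

core._set_prim_forward_enabled(True)  # decompose forward ops into primitives
decomposed = paddle.jit.to_static(bn)(x).numpy()
core._set_prim_forward_enabled(False)

np.testing.assert_allclose(ref, decomposed, rtol=1e-5, atol=1e-5)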
@@ -21,6 +21,7 @@
decomp_interface_declare_gen_op_list = [
"add_n",
"batch_norm",
"batch_norm_",
"dropout",
"full_like",
"gelu",
13 changes: 9 additions & 4 deletions paddle/fluid/pir/dialect/op_generator/op_gen.py
@@ -1177,6 +1177,8 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name):
)
op_interfaces_tmp = op_interfaces
exclusive_interface_str_tmp = exclusive_interface_str
decomp_interface_str = "paddle::dialect::DecompInterface"
decomp_interface_declare_str = "\n static std::vector<std::vector<pir::OpResult>> Decomp(pir::Operation* op);"

# If op has inplace info, we will generate inplace op and non-inplace op.
for op_name in op_info.op_phi_name:
@@ -1212,10 +1214,13 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name):
and kernel_func_name in decomp_interface_declare_gen_op_list
and dialect_name != "onednn_op"
):
op_interfaces = op_interfaces + [
"paddle::dialect::DecompInterface"
]
exclusive_interface_str += "\n static std::vector<std::vector<pir::OpResult>> Decomp(pir::Operation* op);"
if decomp_interface_str not in op_interfaces:
op_interfaces = op_interfaces + [decomp_interface_str]
if (
decomp_interface_declare_str
not in exclusive_interface_str
):
exclusive_interface_str += decomp_interface_declare_str
else:
op_interfaces = op_interfaces_tmp
exclusive_interface_str = exclusive_interface_str_tmp
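The new membership guards appear to be needed because the same loop body now matches both the base op and its inplace variant (batch_norm and batch_norm_ are both in the decomp list), and op_interfaces is only reset when the condition fails; without the checks the interface and the Decomp declaration would be appended twice. A minimal Python sketch of the pattern:

# Minimal sketch of the dedup guard: the interface list is shared across the
# base op and its inplace variant, so only append entries that are missing.
DECOMP_INTERFACE = "paddle::dialect::DecompInterface"
decomp_ops = {"batch_norm", "batch_norm_"}  # both registered for decomp now

interfaces = ["paddle::dialect::InferMetaInterface"]
for op_name in ("batch_norm", "batch_norm_"):
    if op_name in decomp_ops and DECOMP_INTERFACE not in interfaces:
        interfaces = interfaces + [DECOMP_INTERFACE]

assert interfaces.count(DECOMP_INTERFACE) == 1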
90 changes: 89 additions & 1 deletion paddle/fluid/pir/dialect/operator/ir/manual_op_decomp.cc
@@ -60,7 +60,7 @@ std::vector<std::vector<pir::OpResult>> BatchNormOp::Decomp(
float epsilon =
op->attribute("epsilon").dyn_cast<pir::FloatAttribute>().data();
const std::string& data_layout =
op->attribute("data_layout").dyn_cast<pir::StrAttribute>().AsString();
op->attribute("data_format").dyn_cast<pir::StrAttribute>().AsString();
bool use_global_stats =
op->attribute("use_global_stats").dyn_cast<pir::BoolAttribute>().data();
bool trainable_statistics = op->attribute("trainable_statistics")
@@ -117,5 +117,93 @@ std::vector<std::vector<pir::OpResult>> BatchNormOp::Decomp(
return res;
}

std::vector<std::vector<pir::OpResult>> BatchNorm_Op::Decomp(
pir::Operation* op) {
VLOG(4) << "Decomp call batch_norm_'s decomp interface begin";
BatchNorm_Op op_obj = op->dyn_cast<BatchNorm_Op>();
(void)op_obj;

FLAGS_tensor_operants_mode = "static";

VLOG(6) << "Decomp Prepare inputs of batch_norm_";

Tensor x(std::make_shared<primitive::LazyTensor>(op_obj.x()));
Tensor mean(std::make_shared<primitive::LazyTensor>(op_obj.mean()));
Tensor variance(std::make_shared<primitive::LazyTensor>(op_obj.variance()));
paddle::optional<Tensor> scale;
if (!IsEmptyValue(op_obj.scale())) {
scale = paddle::make_optional<Tensor>(
Tensor(std::make_shared<primitive::LazyTensor>(op_obj.scale())));
}
paddle::optional<Tensor> bias;
if (!IsEmptyValue(op_obj.bias())) {
bias = paddle::make_optional<Tensor>(
Tensor(std::make_shared<primitive::LazyTensor>(op_obj.bias())));
}

VLOG(6) << "Decomp prepare attributes of batch_norm_";
bool is_test = op->attribute("is_test").dyn_cast<pir::BoolAttribute>().data();
float momentum =
op->attribute("momentum").dyn_cast<pir::FloatAttribute>().data();
float epsilon =
op->attribute("epsilon").dyn_cast<pir::FloatAttribute>().data();
const std::string& data_layout =
op->attribute("data_format").dyn_cast<pir::StrAttribute>().AsString();
bool use_global_stats =
op->attribute("use_global_stats").dyn_cast<pir::BoolAttribute>().data();
bool trainable_statistics = op->attribute("trainable_statistics")
.dyn_cast<pir::BoolAttribute>()
.data();

VLOG(6) << "Decomp call batch_norm_'s forward composite rule prepare";

auto org_res = op->results();
std::vector<std::vector<pir::OpResult>> res(org_res.size());

VLOG(6) << "Decomp call batch_norm_'s forward composite rule begin";

std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> op_res =
paddle::primitive::details::batch_norm_decomp<primitive::LazyTensor>(
x,
mean,
variance,
scale,
bias,
is_test,
momentum,
epsilon,
data_layout,
use_global_stats,
trainable_statistics);

VLOG(6) << "Decomp call batch_norm_'s forward composite rule end";

res[0].push_back(std::static_pointer_cast<primitive::LazyTensor>(
std::get<0>(op_res).impl())
->value()
.dyn_cast<pir::OpResult>());
res[1].push_back(std::static_pointer_cast<primitive::LazyTensor>(
std::get<1>(op_res).impl())
->value()
.dyn_cast<pir::OpResult>());
res[2].push_back(std::static_pointer_cast<primitive::LazyTensor>(
std::get<2>(op_res).impl())
->value()
.dyn_cast<pir::OpResult>());
res[3].push_back(std::static_pointer_cast<primitive::LazyTensor>(
std::get<3>(op_res).impl())
->value()
.dyn_cast<pir::OpResult>());
res[4].push_back(std::static_pointer_cast<primitive::LazyTensor>(
std::get<4>(op_res).impl())
->value()
.dyn_cast<pir::OpResult>());
pir::OpResult reserve_space;
res[5].push_back(reserve_space);

VLOG(4) << "Decomp call batch_norm_'s decomp interface end";
return res;
}

} // namespace dialect
} // namespace paddle
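For reference, a rough NumPy version of what the training-mode NCHW composite rule computes (an illustration of the standard batch-norm formulas, not Paddle's exact code): saved_variance holds the inverse standard deviation, and reserve_space stays empty, which is why pd_op.batch_norm_ joins decomp_op_contain_none below.

# Rough NumPy reference of the six outputs produced by the decomposition
# (illustrative only; assumes NCHW layout and training mode).
import numpy as np

def batch_norm_ref(x, run_mean, run_var, scale, bias,
                   momentum=0.9, epsilon=1e-5):
    axes = (0, 2, 3)                                  # reduce over N, H, W
    batch_mean = x.mean(axis=axes)
    batch_var = x.var(axis=axes)
    inv_std = 1.0 / np.sqrt(batch_var + epsilon)
    shape = (1, -1, 1, 1)
    x_hat = (x - batch_mean.reshape(shape)) * inv_std.reshape(shape)
    y = x_hat * scale.reshape(shape) + bias.reshape(shape)
    # running statistics blend the previous value with the batch statistics
    mean_out = run_mean * momentum + batch_mean * (1.0 - momentum)
    variance_out = run_var * momentum + batch_var * (1.0 - momentum)
    # saved_variance is the inverse std; reserve_space is intentionally empty
    return y, mean_out, variance_out, batch_mean, inv_std, None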
12 changes: 8 additions & 4 deletions paddle/fluid/primitive/base/decomp_trans.cc
@@ -32,8 +32,11 @@ using Program = pir::Program;

// some outputs like xshape will no longer used after decomp, and those outputs
// will skip checking.
std::unordered_set<std::string> decomp_op_contain_none = {
"pd_op.squeeze", "pd_op.unsqueeze", "pd_op.flatten", "pd_op.batch_norm"};
std::unordered_set<std::string> decomp_op_contain_none = {"pd_op.squeeze",
"pd_op.unsqueeze",
"pd_op.flatten",
"pd_op.batch_norm",
"pd_op.batch_norm_"};

static bool find_value(const std::vector<int64_t>& vec, int64_t value) {
if (std::find(vec.begin(), vec.end(), value) != vec.end()) {
@@ -275,14 +278,15 @@ std::vector<std::vector<pir::OpResult>> call_decomp_rule(pir::Operation* op) {
}

void DecompProgram::decomp_program() {
std::ostringstream orig_prog_stream;
std::unordered_map<pir::OpResult, int> orig_vars_dict;
for (size_t i = 0; i < src_vars_.size(); i++) {
orig_vars_dict[src_vars_[i]] = static_cast<int>(i);
}
std::ostringstream orig_prog_stream;
program_->Print(orig_prog_stream);
VLOG(4) << "[Prim] Origin program bofore decomp :\n"
VLOG(4) << "[Prim] Origin program before decomp :\n"
<< orig_prog_stream.str();

if (!paddle::prim::PrimCommonUtils::IsFwdPrimEnabled()) {
return;
}
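The new set entry matters for output checking: the decomposed batch_norm_ legitimately produces an empty reserve_space, so that output must be exempt from the none-check. A hypothetical sketch of the filtering idea (the helper below is illustrative, not the actual decomp_trans.cc code):

# Hypothetical sketch of why pd_op.batch_norm_ is listed: its decomposition
# leaves reserve_space empty, so the output comparison must skip None results.
decomp_op_contain_none = {"pd_op.squeeze", "pd_op.unsqueeze", "pd_op.flatten",
                          "pd_op.batch_norm", "pd_op.batch_norm_"}

def outputs_to_check(op_name, outputs):
    if op_name in decomp_op_contain_none:
        return [o for o in outputs if o is not None]
    return list(outputs)

outs = ["y", "mean_out", "variance_out", "saved_mean", "saved_variance", None]
assert len(outputs_to_check("pd_op.batch_norm_", outs)) == 5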
6 changes: 3 additions & 3 deletions paddle/fluid/primitive/composite/composite.h
@@ -206,13 +206,13 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> batch_norm_decomp(
y = x_hat * reshape<T>(new_scale, stats_shape) +
reshape<T>(new_bias, stats_shape);
}
if (need_cast) {
y = cast<T>(y, org_dtype);
}
Tensor reserve_space;

auto batch_mean_ = assign<T>(batch_mean);
auto inv_std_ = assign<T>(inv_std);
if (need_cast) {
y = cast<T>(y, org_dtype);
}
if (!use_run_stat) {
return std::make_tuple(
y, run_mean_, run_var_, batch_mean_, inv_std_, reserve_space);
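Moving the cast does not change what is returned: under the low-precision path only y is converted back to the input dtype, while the batch statistics stay FP32, matching the InferMeta change below. A small sketch of that behaviour, assuming the usual FP32 promotion:

# Sketch of the mixed-precision path: compute in FP32, cast only y back to the
# original dtype; the batch statistics remain FP32.
import numpy as np

x = np.random.randn(4, 8, 6, 6).astype(np.float16)
x32 = x.astype(np.float32)                                 # need_cast path
mean = x32.mean(axis=(0, 2, 3), keepdims=True)
var = x32.var(axis=(0, 2, 3), keepdims=True)
y = ((x32 - mean) / np.sqrt(var + 1e-5)).astype(x.dtype)   # cast y back last

assert y.dtype == np.float16
assert mean.dtype == np.float32 and var.dtype == np.float32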
10 changes: 10 additions & 0 deletions paddle/phi/infermeta/multiary.cc
@@ -727,14 +727,24 @@ void BatchNormInferMeta(const MetaTensor& x,
C,
bias.dims()[0]));
}
auto dtype = x.dtype();
if (dtype == phi::DataType::FLOAT16 || dtype == phi::DataType::BFLOAT16 ||
dtype == phi::DataType::UINT16) {
dtype = phi::DataType::FLOAT32;
}

y->set_dims(x_dims);
mean_out->set_dims({C});
mean_out->set_dtype(mean.dtype());
variance_out->set_dims({C});
variance_out->set_dtype(variance.dtype());
if (saved_mean) {
saved_mean->set_dims({C});
saved_mean->set_dtype(dtype);
}
if (saved_variance) {
saved_variance->set_dims({C});
saved_variance->set_dtype(dtype);
}
if (reserve_space) {
reserve_space->set_dims({-1});
1 change: 1 addition & 0 deletions test/legacy_test/test_batch_norm_op_prim_nchw.py
@@ -278,6 +278,7 @@ def initConfig(self):
self.epsilon = 1e-05
self.data_format = "NCHW"
self.use_global_stats = None
self.check_prim_pir = True


class TestBatchNormOpNCHWTestModeFp64(TestBatchNormOp):
3 changes: 1 addition & 2 deletions test/legacy_test/test_batch_norm_op_prim_nhwc.py
@@ -124,6 +124,7 @@ def initConfig(self):
self.epsilon = 1e-05
self.data_format = "NHWC"
self.use_global_stats = None
self.check_prim_pir = True


class TestBatchNormOpNHWCFp16(TestBatchNormOp):
@@ -161,8 +162,6 @@ def initConfig(self):
self.epsilon = 1e-05
self.data_format = "NHWC"
self.use_global_stats = None
# Todo(CZ): open this
self.check_prim_pir = False


class TestBatchNormOpNHWCShape2(TestBatchNormOp):
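Enabling the PIR prim check for the FP16 NHWC case likely also relates to the "fix atol" item in the commit message: a decomposed FP16 batch norm cannot be expected to match at FP32 round-off. A small self-contained illustration of the gap:

# Illustration of why FP16 cases need a looser atol: compare the same NHWC
# batch-norm reference computed in FP16 vs FP32.
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 16, 16, 8)).astype(np.float16)   # NHWC

def bn_nhwc(x, dtype, epsilon=1e-5):
    xc = x.astype(dtype)
    mean = xc.mean(axis=(0, 1, 2))
    var = xc.var(axis=(0, 1, 2))
    return ((xc - mean) / np.sqrt(var + epsilon)).astype(np.float16)

gap = np.abs(bn_nhwc(x, np.float16) - bn_nhwc(x, np.float32)).max()
print(gap)   # typically on the order of 1e-3, far above FP32 round-off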
