
Commit

gn decomp rule supports rank 3 (#63056)
* gn decomp rule supports rank 3

* fix code

* update primitive ops list

* fix code

* update list

* fix bug
cyber-pioneer authored Mar 28, 2024
1 parent b1b0726 commit 602d2ba
Showing 3 changed files with 29 additions and 10 deletions.
7 changes: 7 additions & 0 deletions paddle/fluid/primitive/base/primitive_ops.h
@@ -43,6 +43,7 @@ const std::set<std::string>& GetPrimitiveOpNames() {
       "pd_op.sum",
       "pd_op.abs",
       "pd_op.assign",
+      "pd_op.assign_value",
       "pd_op.concat",
       "pd_op.elementwise_pow",
       "pd_op.rsqrt",
@@ -58,6 +59,8 @@
       "pd_op.min",
       "pd_op.maximum",
       "pd_op.minimum",
+      "pd_op.argmax",
+      "pd_op.argmin",
       "pd_op.prod",
       "pd_op.roll",
       "pd_op.scatter",
@@ -100,11 +103,15 @@
       "pd_op.data",
       "builtin.shadow_output",
       /* skip some special ops */
       "pd_op.conv2d",
       "pd_op.pad3d",
       "pd_op.nearest_interp",
+      "pd_op.squeeze",
+      "pd_op.unsqueeze",
       "pd_op.select_input",
       "pd_op.top_p_sampling",
+      "pd_op.tril",
+      "pd_op.triu",
       "cf.yield",
       "pd_op.increment_",
   };
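For context, GetPrimitiveOpNames() is the whitelist the decomposition pass consults: ops named in it are treated as primitives and survive lowering, which is presumably why the squeeze/unsqueeze ops emitted by the new group_norm rule are registered here. A conceptual sketch of that check, in Python for brevity (the real logic lives in C++ and differs in detail; the op names below are taken from the diff):

PRIMITIVE_OPS = {"pd_op.squeeze", "pd_op.unsqueeze", "pd_op.assign_value"}

def should_decompose(op_name: str) -> bool:
    # Whitelisted ops are kept as-is; every other op may be rewritten
    # into primitives by its registered decomp rule.
    return op_name not in PRIMITIVE_OPS

assert not should_decompose("pd_op.squeeze")  # treated as primitive
assert should_decompose("pd_op.group_norm")   # still gets decomposed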
28 changes: 20 additions & 8 deletions paddle/fluid/primitive/composite/composite.h
@@ -843,19 +843,29 @@ std::tuple<Tensor, Tensor, Tensor> group_norm_decomp(
     // TODO(chengyanfu): support NHWC data format
     PADDLE_THROW(phi::errors::Unimplemented("Only support NCHW format."));
   }
+  size_t rank = x.shape().size();
+  if (rank != 3 && rank != 4) {
+    PADDLE_THROW(
+        phi::errors::Unimplemented("Only support NCHW format in rank 3 or 4."));
+  }
+
   auto org_dtype = x.dtype();
   Tensor x_cast = x;
 
   bool need_cast = is_half_dtype(org_dtype);
   if (need_cast) {
     x_cast = cast<T>(x, DataType::FLOAT32);
   }
+  if (rank == 3) {
+    x_cast = unsqueeze<T>(x_cast, {-1});
+  }
+  Tensor x_dim_t;
   Tensor out, mean_, var_;
-  if (has_dynamic_shape(x.shape())) {
-    Tensor x_dim = shape<T>(x);
+  if (has_dynamic_shape(x_cast.shape())) {
+    x_dim_t = shape<T>(x_cast);
     std::vector<int64_t> one_axis(1, 1);
-    Tensor x_shape = get_slice<T>(x_dim, 0) * groups;
-    Tensor dim_1 = full<T>({1}, -1, x_dim.type());
+    Tensor x_shape = get_slice<T>(x_dim_t, 0) * groups;
+    Tensor dim_1 = full<T>({1}, -1, x_dim_t.type());
     x_shape = concat<T>({x_shape, dim_1});
     x_cast = backend::reshape<T>(x_cast, x_shape);
     mean_ = mean_decomp<T>(x_cast, IntArray(one_axis), true);
@@ -868,9 +878,9 @@ std::tuple<Tensor, Tensor, Tensor> group_norm_decomp(
     Tensor var_inv =
         rsqrt<T>(var_ + full<T>(empty_shape, epsilon, var_.dtype()));
     Tensor res = (x_cast - mean_) * var_inv;
-    out = backend::reshape<T>(res, x_dim);
+    out = backend::reshape<T>(res, x_dim_t);
   } else {
-    auto x_dim = x.shape();
+    auto x_dim = x_cast.shape();
     std::vector<int64_t> one_axis(1, 1);
 
     std::vector<int64_t> x_shape{x_dim[0] * groups, -1};
@@ -903,8 +913,7 @@ std::tuple<Tensor, Tensor, Tensor> group_norm_decomp(
   }
   Tensor mean_out, var_out;
   if (has_dynamic_shape(x.shape())) {
-    Tensor x_dim = shape<T>(x);
-    Tensor x_shape = get_slice<T>(x_dim, 0);
+    Tensor x_shape = get_slice<T>(x_dim_t, 0);
     Tensor dim_1 = full<T>({1}, groups, x_shape.type());
     x_shape = concat<T>({x_shape, dim_1});
     mean_out = backend::reshape<T>(mean_, x_shape);
@@ -918,6 +927,9 @@ std::tuple<Tensor, Tensor, Tensor> group_norm_decomp(
   if (need_cast) {
     out = cast<T>(out, org_dtype);
   }
+  if (rank == 3) {
+    out = squeeze<T>(out, {-1});
+  }
 
   return std::make_tuple(out, mean_out, var_out);
 }
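The math of the new rank-3 path in group_norm_decomp, as a NumPy reference sketch (assumptions: scale/bias application and the dynamic-shape branch are omitted, and variance is computed with the plain E[x^2] - E[x]^2 identity rather than the decomp's exact helpers):

import numpy as np

def group_norm_rank3_ref(x, groups, epsilon=1e-5):
    n, c, l = x.shape
    x4 = x[..., None]                   # mirrors unsqueeze<T>(x_cast, {-1})
    flat = x4.reshape(n * groups, -1)   # mirrors the reshape to [N * groups, -1]
    mean = flat.mean(axis=1, keepdims=True)
    var = (flat**2).mean(axis=1, keepdims=True) - mean**2
    res = (flat - mean) / np.sqrt(var + epsilon)   # mirrors res * rsqrt(var + eps)
    out = res.reshape(n, c, l, 1).squeeze(-1)      # reshape back, squeeze<T>(out, {-1})
    return out, mean.reshape(n, groups), var.reshape(n, groups)

x = np.random.rand(50, 640, 10).astype("float32")
out, mean_, var_ = group_norm_rank3_ref(x, groups=32)
print(out.shape, mean_.shape, var_.shape)  # (50, 640, 10) (50, 32) (50, 32)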
4 changes: 2 additions & 2 deletions test/prim/pir_prim/test_prim_sub_graph_dynamic_shape.py
@@ -424,8 +424,8 @@ class TestPrimGroupNorm3(unittest.TestCase):
     def setUp(self):
         np.random.seed(2023)
         self.dtype = "float32"
-        self.x_shape = [50, 640, 10, 20]
-        self.init_x_shape = [None, 640, None, None]
+        self.x_shape = [50, 640, 10]
+        self.init_x_shape = [None, 640, None]
         self.x = np.random.random(self.x_shape).astype(self.dtype)
         self.net = group_norm_net3
         self.necessary_ops = "pd_op.group_norm"
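The test change just drops the trailing spatial dimension, so the existing dynamic-shape harness now covers a rank-3 group_norm. Roughly what that harness does, sketched under assumptions (group_norm_net3's body and the prim-vs-composite comparison live in the real test file; num_groups=32 and the full_graph flag are guesses):

import numpy as np
import paddle

def group_norm_net3(x):
    # assumption: a 640-channel GroupNorm matching x_shape[1]
    return paddle.nn.GroupNorm(num_groups=32, num_channels=640)(x)

spec = paddle.static.InputSpec(shape=[None, 640, None], dtype="float32")  # init_x_shape
net = paddle.jit.to_static(group_norm_net3, input_spec=[spec], full_graph=True)
x = paddle.to_tensor(np.random.random([50, 640, 10]).astype("float32"))   # x_shape
print(net(x).shape)  # [50, 640, 10] once the decomp accepts rank 3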
