From 838fba67697a3e40dd722f67f4e893835513e49f Mon Sep 17 00:00:00 2001
From: Zero Rains
Date: Mon, 29 Jul 2024 10:18:52 +0800
Subject: [PATCH] [Prim][PIR] Forward decomp the kldiv_loss op (#66510)

* forward decomp the kldiv_loss op and support the dynamic shape for it

* change the tol
---
 .../decomp_interface_gen_op_list.py           |  2 +
 paddle/fluid/primitive/composite/composite.h  | 41 ++++++++++++
 test/legacy_test/test_kldiv_loss_op.py        | 12 +++-
 .../test_prim_sub_graph_dynamic_shape.py      | 63 +++++++++++++++++++
 4 files changed, 116 insertions(+), 2 deletions(-)

diff --git a/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py b/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py
index 7ac860cbda0d0..757e43eb3fb4e 100644
--- a/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py
+++ b/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py
@@ -39,6 +39,7 @@
     "index_sample",
     "index_select",
     "instance_norm",
+    "kldiv_loss",
     "layer_norm",
     "leaky_relu",
     "lerp",
@@ -87,6 +88,7 @@
     "index_sample",
     "index_select",
     "instance_norm",
+    "kldiv_loss",
     "layer_norm",
     "leaky_relu",
     "lerp",
diff --git a/paddle/fluid/primitive/composite/composite.h b/paddle/fluid/primitive/composite/composite.h
index e087aa0e91e42..71e800d76ba3e 100644
--- a/paddle/fluid/primitive/composite/composite.h
+++ b/paddle/fluid/primitive/composite/composite.h
@@ -1555,6 +1555,47 @@ Tensor log_loss_decomp(const Tensor& input,
   return term1 - term2;
 }
 
+template <typename T>
+Tensor kldiv_loss_decomp(const Tensor& x,
+                         const Tensor& label,
+                         const std::string& reduction,
+                         bool log_target) {
+  bool dynamic_shape = has_dynamic_shape(x.shape());
+  Tensor loss;
+  if (log_target) {
+    loss = exp<T>(label) * (label - x);
+  } else {
+    Tensor output = label * (log<T>(label) - x);
+    Tensor zero = full_scalar<T>(0.0, label.dtype());
+    Tensor zeros;
+    if (dynamic_shape) {
+      zeros = backend::full_with_tensor<T>(shape<T>(x), 0, x.dtype());
+    } else {
+      zeros = full<T>(x.shape(), 0, x.dtype());
+    }
+    loss = where<T>(label > zero, output, zeros);
+  }
+
+  if (reduction == "batchmean") {
+    if (x.shape().size() > 0) {
+      if (dynamic_shape) {
+        return sum<T>(loss) / get_slice<T>(shape<T>(x), 0);
+      } else {
+        return sum<T>(loss) / x.shape()[0];
+      }
+    } else {
+      return sum<T>(loss);
+    }
+  }
+  if (reduction == "mean") {
+    return mean_decomp<T>(loss, {}, false);
+  }
+  if (reduction == "sum") {
+    return sum<T>(loss);
+  }
+  return loss;
+}
+
 template <typename T>
 Tensor softsign_decomp(const Tensor& x) {
   // softsign = x / (1 + abs(x))
diff --git a/test/legacy_test/test_kldiv_loss_op.py b/test/legacy_test/test_kldiv_loss_op.py
index dbf421ba263e4..2507f77a3c2e6 100644
--- a/test/legacy_test/test_kldiv_loss_op.py
+++ b/test/legacy_test/test_kldiv_loss_op.py
@@ -46,6 +46,8 @@ def setUp(self):
         self.initTestCase()
         self.op_type = 'kldiv_loss'
         self.python_api = kl_div
+        self.prim_op_type = "comp"
+        self.public_python_api = paddle.nn.functional.kl_div
         x = np.random.uniform(-10, 10, self.x_shape).astype('float64')
         target = np.random.uniform(-10, 10, self.x_shape).astype('float64')
 
@@ -62,10 +64,16 @@ def setUp(self):
         self.outputs = {'Loss': loss.astype('float64')}
 
     def test_check_output(self):
-        self.check_output(check_pir=True)
+        self.check_output(check_pir=True, check_prim_pir=True)
 
     def test_check_grad(self):
-        self.check_grad(['X'], 'Loss', no_grad_set={"Target"}, check_pir=True)
+        self.check_grad(
+            ['X'],
+            'Loss',
+            no_grad_set={"Target"},
+            check_pir=True,
+            check_prim_pir=True,
+        )
 
     def initTestCase(self):
         self.x_shape = (4, 5, 5)
diff --git a/test/prim/pir_prim/test_prim_sub_graph_dynamic_shape.py b/test/prim/pir_prim/test_prim_sub_graph_dynamic_shape.py
index b7963464817b7..313c30373f34b 100644
--- a/test/prim/pir_prim/test_prim_sub_graph_dynamic_shape.py
+++ b/test/prim/pir_prim/test_prim_sub_graph_dynamic_shape.py
@@ -188,6 +188,18 @@ def log_loss_net(inputs, labels):
     return paddle.nn.functional.log_loss(inputs, labels)
 
 
+def kldiv_loss_net1(x, target):
+    return paddle.nn.functional.kl_div(x, target, "batchmean", False)
+
+
+def kldiv_loss_net2(x, target):
+    return paddle.nn.functional.kl_div(x, target, "none", False)
+
+
+def kldiv_loss_net3(x, target):
+    return paddle.nn.functional.kl_div(x, target, "batchmean", True)
+
+
 class TestPrimBase(unittest.TestCase):
     def setUp(self):
         np.random.seed(2023)
@@ -808,5 +820,56 @@ def setUp(self):
         self.tol = 1e-5
 
 
+class TestPrimKLDivLoss1(TestPrimTwo):
+    def setUp(self):
+        np.random.seed(2023)
+        self.shape_x = [40, 20, 50]
+        self.shape_y = [40, 20, 50]
+        self.dtype_x = "float32"
+        self.dtype_y = "float32"
+        self.init_x_shape = [None, None, 50]
+        self.init_y_shape = [None, None, 50]
+        self.x = np.random.uniform(-10, 10, self.shape_x).astype(self.dtype_x)
+        self.y = np.random.uniform(-10, 10, self.shape_y).astype(self.dtype_y)
+        self.net = kldiv_loss_net1
+        self.necessary_ops = "pd_op.kldiv_loss"
+        self.enable_cinn = False
+        self.tol = 1e-6
+
+
+class TestPrimKLDivLoss2(TestPrimTwo):
+    def setUp(self):
+        np.random.seed(2023)
+        self.shape_x = [40, 20, 50]
+        self.shape_y = [40, 20, 50]
+        self.dtype_x = "float32"
+        self.dtype_y = "float32"
+        self.init_x_shape = [None, None, 50]
+        self.init_y_shape = [None, None, 50]
+        self.x = np.random.uniform(-10, 10, self.shape_x).astype(self.dtype_x)
+        self.y = np.random.uniform(-10, 10, self.shape_y).astype(self.dtype_y)
+        self.net = kldiv_loss_net2
+        self.necessary_ops = "pd_op.kldiv_loss"
+        self.enable_cinn = False
+        self.tol = 1e-4
+
+
+class TestPrimKLDivLoss3(TestPrimTwo):
+    def setUp(self):
+        np.random.seed(2023)
+        self.shape_x = [40, 20, 50]
+        self.shape_y = [40, 20, 50]
+        self.dtype_x = "float32"
+        self.dtype_y = "float32"
+        self.init_x_shape = [None, None, 50]
+        self.init_y_shape = [None, None, 50]
+        self.x = np.random.uniform(-10, 10, self.shape_x).astype(self.dtype_x)
+        self.y = np.random.uniform(-10, 10, self.shape_y).astype(self.dtype_y)
+        self.net = kldiv_loss_net3
+        self.necessary_ops = "pd_op.kldiv_loss"
+        self.enable_cinn = False
+        self.tol = 1e-6
+
+
 if __name__ == "__main__":
     unittest.main()
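
Reviewer note (not part of the patch): as a quick reference for what the new composite rule computes, the following minimal NumPy sketch mirrors kldiv_loss_decomp, assuming the documented semantics of paddle.nn.functional.kl_div. The helper name kldiv_loss_ref is illustrative only; for matching inputs its output should agree with the op (and with the decomposed program) up to the tolerances used in the new tests.

import numpy as np

def kldiv_loss_ref(x, label, reduction="mean", log_target=False):
    # Mirrors kldiv_loss_decomp: compute the pointwise loss, then reduce.
    if log_target:
        loss = np.exp(label) * (label - x)
    else:
        with np.errstate(divide="ignore", invalid="ignore"):
            out = label * (np.log(label) - x)
        # Entries with non-positive labels contribute zero, as in the decomp.
        loss = np.where(label > 0, out, 0.0)
    if reduction == "batchmean":
        # Sum divided by the size of the first (batch) dimension.
        return loss.sum() / x.shape[0] if x.ndim > 0 else loss.sum()
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    return loss  # reduction == "none"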