[Prim][PIR] Forward decomp the kldiv_loss op (#66510)
* forward decomp the kldiv_loss op and support dynamic shape for it

* change the tol
zeroRains authored Jul 29, 2024
1 parent c39ac39 commit 838fba6
Showing 4 changed files with 116 additions and 2 deletions.
@@ -39,6 +39,7 @@
"index_sample",
"index_select",
"instance_norm",
"kldiv_loss",
"layer_norm",
"leaky_relu",
"lerp",
@@ -87,6 +88,7 @@
"index_sample",
"index_select",
"instance_norm",
"kldiv_loss",
"layer_norm",
"leaky_relu",
"lerp",
41 changes: 41 additions & 0 deletions paddle/fluid/primitive/composite/composite.h
@@ -1555,6 +1555,47 @@ Tensor log_loss_decomp(const Tensor& input,
  return term1 - term2;
}

template <typename T>
Tensor kldiv_loss_decomp(const Tensor& x,
                         const Tensor& label,
                         const std::string& reduction,
                         bool log_target) {
  bool dynamic_shape = has_dynamic_shape(x.shape());
  Tensor loss;
  if (log_target) {
    loss = exp<T>(label) * (label - x);
  } else {
    Tensor output = label * (log<T>(label) - x);
    Tensor zero = full_scalar<T>(0.0, label.dtype());
    Tensor zeros;
    if (dynamic_shape) {
      zeros = backend::full_with_tensor<T>(shape<T>(x), 0, x.dtype());
    } else {
      zeros = full<T>(x.shape(), 0, x.dtype());
    }
    loss = where<T>(label > zero, output, zeros);
  }

  if (reduction == "batchmean") {
    if (x.shape().size() > 0) {
      if (dynamic_shape) {
        return sum<T>(loss) / get_slice<T>(shape<T>(x), 0);
      } else {
        return sum<T>(loss) / x.shape()[0];
      }
    } else {
      return sum<T>(loss);
    }
  }
  if (reduction == "mean") {
    return mean_decomp<T>(loss, {}, false);
  }
  if (reduction == "sum") {
    return sum<T>(loss);
  }
  return loss;
}

template <typename T>
Tensor softsign_decomp(const Tensor& x) {
// softsign = x / (1 + abs(x))
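For reference, kldiv_loss_decomp above rewrites the op into primitives: with log_target false the loss is label * (log(label) - x), forced to zero where label <= 0; with log_target true it is exp(label) * (label - x); the requested reduction is then applied. A minimal NumPy sketch of the same math (not part of the commit; kldiv_loss_ref is a hypothetical helper for checking the decomposition):

import numpy as np


def kldiv_loss_ref(x, label, reduction="mean", log_target=False):
    # Hypothetical reference: x holds log-probabilities.
    if log_target:
        loss = np.exp(label) * (label - x)
    else:
        with np.errstate(divide="ignore", invalid="ignore"):
            # entries with label <= 0 contribute zero loss
            loss = np.where(label > 0, label * (np.log(label) - x), 0.0)
    if reduction == "batchmean":
        return loss.sum() / x.shape[0] if x.ndim > 0 else loss.sum()
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    return loss  # reduction == "none"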
12 changes: 10 additions & 2 deletions test/legacy_test/test_kldiv_loss_op.py
@@ -46,6 +46,8 @@ def setUp(self):
        self.initTestCase()
        self.op_type = 'kldiv_loss'
        self.python_api = kl_div
        self.prim_op_type = "comp"
        self.public_python_api = paddle.nn.functional.kl_div
        x = np.random.uniform(-10, 10, self.x_shape).astype('float64')
        target = np.random.uniform(-10, 10, self.x_shape).astype('float64')

@@ -62,10 +64,16 @@ def setUp(self):
        self.outputs = {'Loss': loss.astype('float64')}

    def test_check_output(self):
        self.check_output(check_pir=True)
        self.check_output(check_pir=True, check_prim_pir=True)

    def test_check_grad(self):
        self.check_grad(['X'], 'Loss', no_grad_set={"Target"}, check_pir=True)
        self.check_grad(
            ['X'],
            'Loss',
            no_grad_set={"Target"},
            check_pir=True,
            check_prim_pir=True,
        )

    def initTestCase(self):
        self.x_shape = (4, 5, 5)
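Setting self.prim_op_type = "comp" together with self.public_python_api makes the op test also check the composite decomposition under PIR (check_prim_pir=True) against the original kernel. For context, a small illustrative sketch of the public API being decomposed; the tensors and shapes here are made up, not taken from the test:

import paddle
import paddle.nn.functional as F

x = paddle.rand([4, 5, 5])       # treated as log-probabilities
target = paddle.rand([4, 5, 5])  # probabilities (log_target left unset)

loss_none = F.kl_div(x, target, reduction="none")        # element-wise loss, same shape as x
loss_batch = F.kl_div(x, target, reduction="batchmean")  # sum(loss) / batch size
loss_sum = F.kl_div(x, target, reduction="sum")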
63 changes: 63 additions & 0 deletions test/prim/pir_prim/test_prim_sub_graph_dynamic_shape.py
@@ -188,6 +188,18 @@ def log_loss_net(inputs, labels):
    return paddle.nn.functional.log_loss(inputs, labels)


def kldiv_loss_net1(x, target):
    return paddle.nn.functional.kl_div(x, target, "batchmean", False)


def kldiv_loss_net2(x, target):
    return paddle.nn.functional.kl_div(x, target, "none", False)


def kldiv_loss_net3(x, target):
    return paddle.nn.functional.kl_div(x, target, "batchmean", True)


class TestPrimBase(unittest.TestCase):
    def setUp(self):
        np.random.seed(2023)
@@ -808,5 +820,56 @@ def setUp(self):
        self.tol = 1e-5


class TestPrimKLDivLoss1(TestPrimTwo):
    def setUp(self):
        np.random.seed(2023)
        self.shape_x = [40, 20, 50]
        self.shape_y = [40, 20, 50]
        self.dtype_x = "float32"
        self.dtype_y = "float32"
        self.init_x_shape = [None, None, 50]
        self.init_y_shape = [None, None, 50]
        self.x = np.random.uniform(-10, 10, self.shape_x).astype(self.dtype_x)
        self.y = np.random.uniform(-10, 10, self.shape_y).astype(self.dtype_y)
        self.net = kldiv_loss_net1
        self.necessary_ops = "pd_op.kldiv_loss"
        self.enable_cinn = False
        self.tol = 1e-6


class TestPrimKLDivLoss2(TestPrimTwo):
    def setUp(self):
        np.random.seed(2023)
        self.shape_x = [40, 20, 50]
        self.shape_y = [40, 20, 50]
        self.dtype_x = "float32"
        self.dtype_y = "float32"
        self.init_x_shape = [None, None, 50]
        self.init_y_shape = [None, None, 50]
        self.x = np.random.uniform(-10, 10, self.shape_x).astype(self.dtype_x)
        self.y = np.random.uniform(-10, 10, self.shape_y).astype(self.dtype_y)
        self.net = kldiv_loss_net2
        self.necessary_ops = "pd_op.kldiv_loss"
        self.enable_cinn = False
        self.tol = 1e-4


class TestPrimKLDivLoss3(TestPrimTwo):
    def setUp(self):
        np.random.seed(2023)
        self.shape_x = [40, 20, 50]
        self.shape_y = [40, 20, 50]
        self.dtype_x = "float32"
        self.dtype_y = "float32"
        self.init_x_shape = [None, None, 50]
        self.init_y_shape = [None, None, 50]
        self.x = np.random.uniform(-10, 10, self.shape_x).astype(self.dtype_x)
        self.y = np.random.uniform(-10, 10, self.shape_y).astype(self.dtype_y)
        self.net = kldiv_loss_net3
        self.necessary_ops = "pd_op.kldiv_loss"
        self.enable_cinn = False
        self.tol = 1e-6


if __name__ == "__main__":
    unittest.main()
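The three test classes above plug into the TestPrimTwo harness presumably defined earlier in this file: the nets run with partially unknown input shapes ([None, None, 50]), so the dynamic-shape branch of the decomposition (runtime shape plus get_slice) has to be used, and the harness presumably checks that pd_op.kldiv_loss is decomposed while results still match within self.tol. A rough standalone sketch of tracing one of these nets with unknown leading dimensions (an illustrative assumption, not the actual test flow):

import numpy as np
import paddle
from paddle.static import InputSpec

# Hypothetical sketch: trace kldiv_loss_net1 with unknown leading dims so the
# decomposition's dynamic-shape branch is exercised at runtime.
static_net = paddle.jit.to_static(
    kldiv_loss_net1,
    input_spec=[
        InputSpec(shape=[None, None, 50], dtype="float32"),
        InputSpec(shape=[None, None, 50], dtype="float32"),
    ],
    full_graph=True,
)

x = paddle.to_tensor(np.random.uniform(-10, 10, [40, 20, 50]).astype("float32"))
t = paddle.to_tensor(np.random.uniform(-10, 10, [40, 20, 50]).astype("float32"))
out = static_net(x, t)  # "batchmean" reduction, so a scalar tensor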
