【PIR API adaptor No.39、123】Migrate label_smooth & class_center_sample into pir (PaddlePaddle#58693)

* add common.py

* rm default_main_program && create new func

* add static.program_guard

* rm dynamic_and_pir_mode_test func
DrRyanHuang authored Nov 7, 2023
1 parent 19b3662 commit 6917abb
Showing 4 changed files with 16 additions and 9 deletions.
2 changes: 1 addition & 1 deletion python/paddle/nn/functional/common.py
@@ -2236,7 +2236,7 @@ class centers and the shape of sampled_class_center will be [num_positive_class_
     if (seed is None or seed == 0) and default_main_program().random_seed != 0:
         seed = default_main_program().random_seed

-    if in_dynamic_mode():
+    if in_dynamic_or_pir_mode():
         return _C_ops.class_center_sample(
             label,
             num_classes,
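
With this dispatch change, both dygraph and PIR execution now take the `_C_ops.class_center_sample` path. A minimal dynamic-mode usage sketch; the class count, batch size, and sample count below are illustrative and not taken from this PR:

import paddle

num_classes, batch_size, num_samples = 20, 10, 6
# random integer labels in [0, num_classes)
label = paddle.randint(low=0, high=num_classes, shape=[batch_size], dtype='int64')
# returns the remapped labels and the indices of the sampled class centers
remapped_label, sampled_class_index = paddle.nn.functional.class_center_sample(
    label, num_classes, num_samples
)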
13 changes: 9 additions & 4 deletions test/legacy_test/test_class_center_sample_op.py
@@ -19,7 +19,8 @@
 from op_test import OpTest, paddle_static_guard

 import paddle
-from paddle.base import Program, core, program_guard
+from paddle.base import core
+from paddle.pir_utils import test_with_pir_api


 def class_center_sample_numpy(label, classes_list, num_samples):
@@ -118,7 +119,9 @@ def setUp(self):
         }

     def test_check_output(self):
-        self.check_output(no_check_set=['SampledLocalClassCenter'])
+        self.check_output(
+            no_check_set=['SampledLocalClassCenter'], check_pir=True
+        )


 class TestClassCenterSampleOpINT32(TestClassCenterSampleOp):
@@ -160,9 +163,12 @@ def test_static(self):
         for place in self.places:
             self.check_static_result(place=place)

+    @test_with_pir_api
     def check_static_result(self, place):
         with paddle_static_guard():
-            with program_guard(Program(), Program()):
+            main = paddle.static.Program()
+            startup = paddle.static.Program()
+            with paddle.static.program_guard(main, startup):
                 label_np = np.random.randint(
                     0, self.num_classes, (self.batch_size,), dtype=self.dtype
                 )
@@ -185,7 +191,6 @@ def check_static_result(self, place):
                 )
                 exe = paddle.base.Executor(place)
                 [remapped_label_res, sampled_class_index_res] = exe.run(
-                    paddle.base.default_main_program(),
                     feed={'label': label_np},
                     fetch_list=[remapped_label, sampled_class_index],
                 )
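
The hunks above switch to fresh `paddle.static.Program` objects under `paddle.static.program_guard` and drop the explicit program argument to `Executor.run`, which is what lets `@test_with_pir_api` re-run the same body under both the legacy IR and PIR. A standalone sketch of that pattern outside the OpTest harness, with illustrative shapes, class counts, and a CPU place:

import numpy as np
import paddle

paddle.enable_static()

main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
    label = paddle.static.data(name='label', shape=[10], dtype='int64')
    remapped_label, sampled_class_index = paddle.nn.functional.class_center_sample(
        label, num_classes=20, num_samples=6
    )
    exe = paddle.static.Executor(paddle.CPUPlace())
    # inside the guard, exe.run() executes `main` by default, so the
    # program does not need to be passed explicitly
    remapped_np, sampled_np = exe.run(
        feed={'label': np.random.randint(0, 20, (10,), dtype='int64')},
        fetch_list=[remapped_label, sampled_class_index],
    )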
2 changes: 2 additions & 0 deletions test/legacy_test/test_label_smooth_functional.py
@@ -20,6 +20,7 @@
 import paddle.base.dygraph as dg
 import paddle.nn.functional as F
 from paddle import base
+from paddle.pir_utils import test_with_pir_api


 class LabelSmoothTestCase(unittest.TestCase):
@@ -88,6 +89,7 @@ def paddle_dygraph_layer(self):
         y_np = y_var.numpy()
         return y_np

+    @test_with_pir_api
     def _test_equivalence(self, place):
         place = base.CPUPlace()
         result1 = self.base_layer(place)
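
With the decorator applied, the equivalence check between the static and dygraph paths in this test case is also exercised under PIR. For context, a minimal dygraph call to the API under test; the batch size, class count, and epsilon are illustrative:

import paddle
import paddle.nn.functional as F

# one-hot targets for 4 samples over 5 classes
onehot = F.one_hot(paddle.to_tensor([0, 2, 1, 4]), num_classes=5)
smoothed = F.label_smooth(onehot, epsilon=0.1)
# with no prior_dist, each row becomes (1 - 0.1) * onehot + 0.1 / 5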
8 changes: 4 additions & 4 deletions test/legacy_test/test_label_smooth_op.py
@@ -45,10 +45,10 @@ def init_dtype(self):
         self.dtype = np.float64

     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_pir=True)

     def test_check_grad(self):
-        self.check_grad(["X"], "Out")
+        self.check_grad(["X"], "Out", check_pir=True)


 @unittest.skipIf(
@@ -77,11 +77,11 @@ def setUp(self):

     def test_check_output(self):
         place = core.CUDAPlace(0)
-        self.check_output_with_place(place)
+        self.check_output_with_place(place, check_pir=True)

     def test_check_grad(self):
         place = core.CUDAPlace(0)
-        self.check_grad_with_place(place, ["X"], "Out")
+        self.check_grad_with_place(place, ["X"], "Out", check_pir=True)


 class TestLabelSmoothFP16OP(TestLabelSmoothOp):
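
Passing `check_pir=True` asks the OpTest harness to run the output and gradient checks under PIR as well. As a reminder of what the label_smooth op computes, a small NumPy sketch of the kind of reference these tests compare against; epsilon and shapes are illustrative, and a uniform prior is assumed when `prior_dist` is not given:

import numpy as np

def label_smooth_ref(label, epsilon=0.1, prior_dist=None):
    # blend the one-hot label with a prior distribution; when no prior is
    # given, a uniform distribution over the classes is assumed
    num_classes = label.shape[-1]
    if prior_dist is None:
        prior_dist = np.full((num_classes,), 1.0 / num_classes)
    return (1.0 - epsilon) * label + epsilon * prior_dist

onehot = np.eye(5)[[0, 2, 1]]     # 3 samples, 5 classes
print(label_smooth_ref(onehot))   # 0.92 at the true class, 0.02 elsewhere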
