Skip to content

Commit

Permalink
【PIR API adaptor No.251-252】 paddle.vision.ops.yolo_box and yolo_loss (
Browse files Browse the repository at this point in the history
  • Loading branch information
Liyulingyue authored Nov 15, 2023
1 parent da53011 commit fb7c80e
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 3 deletions.
4 changes: 2 additions & 2 deletions python/paddle/vision/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def yolo_loss(
... scale_x_y=1.)
"""

if in_dygraph_mode():
if in_dynamic_or_pir_mode():
loss = _C_ops.yolo_loss(
x,
gt_box,
Expand Down Expand Up @@ -365,7 +365,7 @@ def yolo_box(
... clip_bbox=True,
... scale_x_y=1.)
"""
if in_dygraph_mode():
if in_dynamic_or_pir_mode():
boxes, scores = _C_ops.yolo_box(
x,
img_size,
Expand Down
4 changes: 3 additions & 1 deletion test/legacy_test/test_yolo_box_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from op_test import OpTest

import paddle
from paddle.pir_utils import test_with_pir_api


def sigmoid(x):
Expand Down Expand Up @@ -141,7 +142,7 @@ def setUp(self):
self.outputs = {'Boxes': boxes, 'Scores': scores}

def test_check_output(self):
self.check_output()
self.check_output(check_pir=True)

def initTestCase(self):
self.anchors = [10, 13, 16, 30, 33, 23]
Expand Down Expand Up @@ -262,6 +263,7 @@ def test_dygraph(self):


class TestYoloBoxStatic(unittest.TestCase):
@test_with_pir_api
def test_static(self):
x1 = paddle.static.data('x1', [2, 14, 8, 8], 'float32')
img_size = paddle.static.data('img_size', [2, 2], 'int32')
Expand Down
2 changes: 2 additions & 0 deletions test/legacy_test/test_yolov3_loss_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import paddle
from paddle.base import core
from paddle.pir_utils import test_with_pir_api


def l1loss(x, y):
Expand Down Expand Up @@ -438,6 +439,7 @@ def test_dygraph(self):


class TestYolov3LossStatic(unittest.TestCase):
@test_with_pir_api
def test_static(self):
x = paddle.static.data('x', [2, 14, 8, 8], 'float32')
gt_box = paddle.static.data('gt_box', [2, 10, 4], 'float32')
Expand Down

0 comments on commit fb7c80e

Please sign in to comment.