Skip to content

Commit

Permalink
【PIR API adaptor No.190、191】 Migrate paddle.scatter,paddle.scatter_nd…
Browse files Browse the repository at this point in the history
…_add into pir (#58548)
  • Loading branch information
enkilee authored Nov 7, 2023
1 parent 7806bff commit 77e4ada
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 47 deletions.
4 changes: 2 additions & 2 deletions python/paddle/tensor/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3084,7 +3084,7 @@ def scatter(x, index, updates, overwrite=True, name=None):
>>> # [2., 2.],
>>> # [1., 1.]]
"""
if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.scatter(x, index, updates, overwrite)
else:
check_variable_and_dtype(
Expand Down Expand Up @@ -3183,7 +3183,7 @@ def scatter_nd_add(x, index, updates, name=None):
>>> print(output.shape)
[3, 5, 9, 10]
"""
if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.scatter_nd_add(x, index, updates)
else:
if x.dtype != updates.dtype:
Expand Down
65 changes: 42 additions & 23 deletions test/legacy_test/test_scatter_nd_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import paddle
from paddle import base
from paddle.base import core
from paddle.base.dygraph.base import switch_to_static_graph
from paddle.pir_utils import test_with_pir_api


def numpy_scatter_nd(ref, index, updates, fun):
Expand Down Expand Up @@ -94,10 +94,12 @@ def _set_dtype(self):
self.dtype = np.float64

def test_check_output(self):
self.check_output(check_cinn=True)
self.check_output(check_cinn=True, check_pir=True)

def test_check_grad(self):
self.check_grad(['X', 'Updates'], 'Out', check_prim=True)
self.check_grad(
['X', 'Updates'], 'Out', check_prim=True, check_pir=True
)


class TestScatterNdAddSimpleFP16Op(TestScatterNdAddSimpleOp):
Expand Down Expand Up @@ -125,13 +127,13 @@ def _set_dtype(self):
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_output_with_place(place)
self.check_output_with_place(place, check_pir=True)

def test_check_grad(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_grad_with_place(
place, ['X', 'Updates'], 'Out', check_prim=True
place, ['X', 'Updates'], 'Out', check_prim=True, check_pir=True
)


Expand Down Expand Up @@ -170,10 +172,12 @@ def _set_dtype(self):
self.dtype = np.float64

def test_check_output(self):
self.check_output(check_cinn=True)
self.check_output(check_cinn=True, check_pir=True)

def test_check_grad(self):
self.check_grad(['X', 'Updates'], 'Out', check_prim=True)
self.check_grad(
['X', 'Updates'], 'Out', check_prim=True, check_pir=True
)


class TestScatterNdAddWithEmptyIndexFP16(TestScatterNdAddWithEmptyIndex):
Expand Down Expand Up @@ -201,13 +205,13 @@ def _set_dtype(self):
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_output_with_place(place)
self.check_output_with_place(place, check_pir=True)

def test_check_grad(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_grad_with_place(
place, ['X', 'Updates'], 'Out', check_prim=True
place, ['X', 'Updates'], 'Out', check_prim=True, check_pir=True
)


Expand Down Expand Up @@ -249,10 +253,12 @@ def _set_dtype(self):
self.dtype = np.float64

def test_check_output(self):
self.check_output(check_cinn=True)
self.check_output(check_cinn=True, check_pir=True)

def test_check_grad(self):
self.check_grad(['X', 'Updates'], 'Out', check_prim=True)
self.check_grad(
['X', 'Updates'], 'Out', check_prim=True, check_pir=True
)


class TestScatterNdAddWithHighRankSameFP16(TestScatterNdAddWithHighRankSame):
Expand Down Expand Up @@ -280,13 +286,13 @@ def _set_dtype(self):
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_output_with_place(place)
self.check_output_with_place(place, check_pir=True)

def test_check_grad(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_grad_with_place(
place, ['X', 'Updates'], 'Out', check_prim=True
place, ['X', 'Updates'], 'Out', check_prim=True, check_pir=True
)


Expand All @@ -312,10 +318,12 @@ def setUp(self):
self.outputs = {'Out': expect_np}

def test_check_output(self):
self.check_output(check_cinn=True)
self.check_output(check_cinn=True, check_pir=True)

def test_check_grad(self):
self.check_grad(['X', 'Updates'], 'Out', check_prim=True)
self.check_grad(
['X', 'Updates'], 'Out', check_prim=True, check_pir=True
)


# Test Python API
Expand Down Expand Up @@ -422,7 +430,7 @@ def testcase5(self):
np.testing.assert_array_equal(gpu_value.numpy(), cpu_value.numpy())
paddle.set_device(device)

@switch_to_static_graph
@test_with_pir_api
def test_static_graph():
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
Expand All @@ -434,15 +442,26 @@ def test_static_graph():
val_t = paddle.static.data(
name="val", dtype=val.dtype, shape=val.shape
)
out_t = paddle.scatter_nd_add(x_t, index_t, val_t)
feed = {x_t.name: x, index_t.name: index, val_t.name: val}
fetch = [out_t]

gpu_exe = paddle.static.Executor(paddle.CUDAPlace(0))
gpu_value = gpu_exe.run(feed=feed, fetch_list=fetch)[0]
cpu_exe = paddle.static.Executor(paddle.CPUPlace())
cpu_value = cpu_exe.run(feed=feed, fetch_list=fetch)[0]
np.testing.assert_array_equal(gpu_value, cpu_value)
out_t = paddle.scatter_nd_add(x_t, index_t, val_t)
gpu_value = gpu_exe.run(
feed={
'x': x,
'index': index,
'val': val,
},
fetch_list=[out_t],
)
cpu_value = cpu_exe.run(
feed={
'x': x,
'index': index,
'val': val,
},
fetch_list=[out_t],
)
np.testing.assert_array_equal(gpu_value, cpu_value)

test_static_graph()

Expand Down
Loading

0 comments on commit 77e4ada

Please sign in to comment.