Skip to content

Commit

Permalink
[PIR]Migrate embedding and fused_softmax_mask_upper_triangle to pir (P…
Browse files Browse the repository at this point in the history
  • Loading branch information
Aurelius84 authored Sep 16, 2023
1 parent 1812999 commit 2c6d4e1
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
'softmax',
'silu',
'elementwise_pow',
'embedding',
'fused_softmax_mask_upper_triangle',
'slice',
'transpose',
Expand Down Expand Up @@ -79,6 +80,7 @@
'softmax',
'silu',
'elementwise_pow',
'embedding',
'fused_softmax_mask_upper_triangle',
'slice',
'transpose',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle import _legacy_C_ops
from paddle import _C_ops
from paddle.base.layer_helper import LayerHelper
from paddle.framework import in_dynamic_mode
from paddle.framework import in_dynamic_or_pir_mode


def softmax_mask_fuse_upper_triangle(x):
Expand Down Expand Up @@ -68,8 +68,8 @@ def softmax_mask_fuse_upper_triangle(x):
[0.02280738, 0.03144657, 0.02892209, ..., 0.03885521,
0.03342311, 0.02842640]]]])
"""
if in_dynamic_mode():
out = _legacy_C_ops.fused_softmax_mask_upper_triangle(x)
if in_dynamic_or_pir_mode():
out = _C_ops.fused_softmax_mask_upper_triangle(x)
return out

helper = LayerHelper('fused_softmax_mask_upper_triangle', **locals())
Expand Down
4 changes: 2 additions & 2 deletions python/paddle/nn/functional/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from ...base.data_feeder import check_variable_and_dtype
from ...base.layer_helper import LayerHelper
from ...common_ops_import import Variable
from ...framework import in_dynamic_mode
from ...framework import in_dynamic_mode, in_dynamic_or_pir_mode

__all__ = []

Expand Down Expand Up @@ -224,7 +224,7 @@ def embedding(x, weight, padding_idx=None, sparse=False, name=None):
)
)

if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.embedding(x, weight, padding_idx, sparse)
else:
helper = LayerHelper('embedding', **locals())
Expand Down
35 changes: 27 additions & 8 deletions test/legacy_test/test_lookup_table_v2_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,16 @@ def id_dtype(self):
return "int64"

def test_check_output(self):
self.check_output(check_cinn=True)
self.check_output(check_cinn=True, check_new_ir=True)

def test_check_grad(self):
self.check_grad(['W'], 'Out', no_grad_set=set('Ids'), check_cinn=True)
self.check_grad(
['W'],
'Out',
no_grad_set=set('Ids'),
check_cinn=True,
check_new_ir=True,
)


class TestLookupTableOpInt16(OpTest):
Expand Down Expand Up @@ -93,10 +99,16 @@ def setUp(self):
self.outputs = {'Out': table[ids.flatten()].reshape((2, 4, 5, 31))}

def test_check_output(self):
self.check_output(check_cinn=True)
self.check_output(check_cinn=True, check_new_ir=True)

def test_check_grad(self):
self.check_grad(['W'], 'Out', no_grad_set=set('Ids'), check_cinn=True)
self.check_grad(
['W'],
'Out',
no_grad_set=set('Ids'),
check_cinn=True,
check_new_ir=True,
)


@skip_check_grad_ci(
Expand All @@ -110,7 +122,7 @@ def test_check_output(self):
padding_idx = np.random.choice(ids, 1)[0]
self.outputs['Out'][ids == padding_idx] = np.zeros(31)
self.attrs = {'padding_idx': int(padding_idx)}
self.check_output(check_cinn=True)
self.check_output(check_cinn=True, check_new_ir=True)


@skip_check_grad_ci(
Expand All @@ -125,7 +137,7 @@ def test_check_output(self):
padding_idx = np.random.choice(flatten_idx, 1)[0]
self.outputs['Out'][np.squeeze(ids == padding_idx)] = np.zeros(31)
self.attrs = {'padding_idx': padding_idx}
self.check_output(check_cinn=True)
self.check_output(check_cinn=True, check_new_ir=True)


class TestLookupTableWIsSelectedRows(unittest.TestCase):
Expand Down Expand Up @@ -203,6 +215,7 @@ def init_data(self):
self.y_data = np.array([[0.1, 0.3, 0, 0.4, 0.7]]).astype("float32")

def get_w_grad(self, is_sparse):
paddle.enable_static()
self.init_data()
main_program = base.Program()
with base.program_guard(main_program, base.Program()):
Expand Down Expand Up @@ -250,6 +263,7 @@ def check_grad(self, w_grad1, w_grad2, tolerance=1e-6):

class TestLookupTableApi(unittest.TestCase):
def test_api(self):
paddle.enable_static()
x = paddle.static.data(name='x', shape=[-1, 20], dtype='int64')
emb = paddle.static.nn.embedding(input=x, size=[128, 64])

Expand Down Expand Up @@ -341,12 +355,17 @@ def id_dtype(self):

def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(place, check_cinn=True)
self.check_output_with_place(place, check_cinn=True, check_new_ir=True)

def test_check_grad(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(
place, ['W'], 'Out', no_grad_set=set('Ids'), check_cinn=True
place,
['W'],
'Out',
no_grad_set=set('Ids'),
check_cinn=True,
check_new_ir=True,
)


Expand Down
12 changes: 8 additions & 4 deletions test/legacy_test/test_softmax_mask_fuse_upper_triangle_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,12 @@ def setUp(self):
self.outputs = {'Out': rst}

def test_check_output(self):
self.check_output_with_place(core.CUDAPlace(0))
self.check_output_with_place(core.CUDAPlace(0), check_new_ir=True)

def test_check_grad(self):
self.check_grad_with_place(core.CUDAPlace(0), ["X"], "Out")
self.check_grad_with_place(
core.CUDAPlace(0), ["X"], "Out", check_new_ir=True
)


@unittest.skipIf(
Expand All @@ -70,13 +72,15 @@ def setUp(self):

def test_check_output(self):
try:
self.check_output_with_place(core.CPUPlace())
self.check_output_with_place(core.CPUPlace(), check_new_ir=True)
except (NotImplementedError, RuntimeError):
pass

def test_check_grad(self):
try:
self.check_grad_with_place(core.CPUPlace(), ["X"], "Out")
self.check_grad_with_place(
core.CPUPlace(), ["X"], "Out", check_new_ir=True
)
except (NotImplementedError, RuntimeError):
pass

Expand Down

0 comments on commit 2c6d4e1

Please sign in to comment.