[MLU] fix masked_select (PaddlePaddle#2)
PeiyuLau authored Jun 8, 2023
1 parent 4ef39c2 commit 1ba2313
Showing 2 changed files with 105 additions and 87 deletions.
37 changes: 22 additions & 15 deletions paddle/fluid/operators/masked_select_op_mlu.cc
@@ -42,12 +42,13 @@ class MaskedSelectedMLUKernel : public framework::OpKernel<T> {
Tensor number(framework::TransToPhiDataType(VT::INT32));
void* number_ptr = number.mutable_data<int32_t>({1}, ctx.GetPlace());

out->Resize(mask->dims());
out->mutable_data<T>(ctx.GetPlace());
Tensor mask_select_out;
mask_select_out.mutable_data<T>(mask->dims(), ctx.GetPlace());

MLUCnnlTensorDesc input_desc(*input);
MLUCnnlTensorDesc mask_desc(*mask);
MLUCnnlTensorDesc out_desc(*out);
MLUCnnlTensorDesc masked_select_out_desc(mask_select_out);
MLUCnnl::Mask(ctx,
CNNL_MASKED_SELECT,
input_desc.get(),
@@ -56,9 +57,23 @@ class MaskedSelectedMLUKernel : public framework::OpKernel<T> {
GetBasePtr(mask),
nullptr,
nullptr,
out_desc.get(),
GetBasePtr(out),
masked_select_out_desc.get(),
GetBasePtr(&mask_select_out),
static_cast<uint32_t*>(number_ptr));
auto stream = ctx.template device_context<MLUDeviceContext>().stream();
Tensor number_cpu;
paddle::framework::TensorCopySync(
number, platform::CPUPlace(), &number_cpu);

out->Resize({number_cpu.data<int>()[0]});
out->mutable_data<T>(ctx.GetPlace());

memory::Copy(ctx.GetPlace(),
GetBasePtr(out),
ctx.GetPlace(),
GetBasePtr(&mask_select_out),
number_cpu.data<int>()[0] * sizeof(T),
stream);
}
};
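
Taken together, the fixed forward path selects into a scratch tensor sized like the mask, copies the selected-element count (number) back to the host, and only then resizes out and copies that many elements. A minimal NumPy sketch of the output contract this targets, using a hypothetical helper name rather than the MLU code itself:

import numpy as np

def masked_select_ref(x, mask):
    # Expected contract: a 1-D result whose length equals the number of
    # True entries in mask, not a tensor padded out to mask's full shape.
    return x[mask.astype(bool)].flatten()

x = np.arange(6, dtype=np.float32).reshape(2, 3)
mask = np.array([[True, False, True], [False, True, False]])
print(masked_select_ref(x, mask))        # [0. 2. 4.]
print(masked_select_ref(x, mask).shape)  # (3,), driven by mask.sum(), not mask.shape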

@@ -154,15 +169,7 @@ class MaskedSelectedGradMLUKernel : public framework::OpKernel<T> {
out_size_vec[0] * sizeof(int32_t),
stream);

Tensor y_grad_tmp_out;
y_grad_tmp_out.mutable_data<T>({out_size_vec[0]}, ctx.GetPlace());
MLUCnnlTensorDesc y_grad_tmp_out_desc(y_grad_tmp_out);
memory::Copy(ctx.GetPlace(),
GetBasePtr(&y_grad_tmp_out),
ctx.GetPlace(),
GetBasePtr(y_grad),
out_size_vec[0] * sizeof(T),
stream);
MLUCnnlTensorDesc y_grad_desc(*y_grad);

Tensor indices_int32_tmp;
indices_int32_tmp.ShareDataWith(indices_int32_out);
@@ -177,8 +184,8 @@ class MaskedSelectedGradMLUKernel : public framework::OpKernel<T> {
mode,
indices_int32_tmp_desc.get(),
GetBasePtr(&indices_int32_tmp),
y_grad_tmp_out_desc.get(),
GetBasePtr(&y_grad_tmp_out),
y_grad_desc.get(),
GetBasePtr(y_grad),
nullptr,
nullptr,
x_grad_desc.get(),
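
On the backward side, the change drops the intermediate y_grad_tmp_out buffer and its extra device copy, and feeds y_grad directly into the call that scatters the selected-output gradient back into x_grad. The gradient rule this relies on can be sketched in NumPy as follows (an assumed reference for illustration, not the kernel implementation):

import numpy as np

def masked_select_grad_ref(y_grad, mask):
    # Scatter the gradient of the selected outputs back to the positions
    # where mask is True; every other position of x_grad stays zero.
    x_grad = np.zeros(mask.size, dtype=y_grad.dtype)
    x_grad[mask.reshape(-1).astype(bool)] = y_grad
    return x_grad.reshape(mask.shape)

mask = np.array([[True, False], [False, True]])
y_grad = np.array([1.0, 2.0], dtype=np.float32)
print(masked_select_grad_ref(y_grad, mask))
# [[1. 0.]
#  [0. 2.]]
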
155 changes: 83 additions & 72 deletions python/paddle/fluid/tests/unittests/mlu/test_masked_select_op_mlu.py
@@ -11,88 +11,98 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from __future__ import print_function

import unittest
import sys

sys.path.append("..")
import numpy as np
from op_test import OpTest, skip_check_grad_ci
import paddle.fluid as fluid
import paddle

paddle.enable_static()


def np_masked_select(shape, x, mask):

def np_masked_select(x, mask):
result = np.empty(shape=(0), dtype=x.dtype)
sum = 0
for index, (ele, ma) in enumerate(zip(np.nditer(x), np.nditer(mask))):
for ele, ma in zip(np.nditer(x), np.nditer(mask)):
if ma:
sum = sum + 1
result = np.append(result, ele)
for index, (ele, ma) in enumerate(zip(np.nditer(x), np.nditer(mask))):
if index >= sum:
result = np.append(result, 0)
result = np.reshape(result, shape)
return result


return result.flatten()
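
The reworked reference above mirrors the operator's actual behaviour: it collects only the elements where the mask is True and returns a flat 1-D array, instead of zero-padding the result back to the full input shape as the old helper did. For example (illustrative values, not part of the test file):

x = np.array([[1.0, 2.0], [3.0, 4.0]])
mask = np.array([[True, False], [False, True]])
np_masked_select(x, mask)  # array([1., 4.]); length equals mask.sum()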


class TestMaskedSelectOp(OpTest):

def set_mlu(self):
self.__class__.use_mlu = True

def setUp(self):
self.set_mlu()
self.init()
self.__class__.use_mlu = True
self.init_dtype()
self.place = paddle.device.MLUPlace(0)
self.op_type = "masked_select"
self.python_api = paddle.masked_select
x = np.random.random(self.shape).astype('float32')
x = np.random.random(self.shape).astype(self.dtype)
mask = np.array(np.random.randint(2, size=self.shape, dtype=bool))
out = np_masked_select(self.shape, x, mask)
out = np_masked_select(x, mask)
self.inputs = {'X': x, 'Mask': mask}
self.outputs = {'Y': out}

def test_check_output(self):
self.check_output_with_place(self.place)

def test_check_grad(self):
self.check_grad_with_place(self.place, ['X'], 'Y')

def init(self):
self.shape = (50, 3)



def init_dtype(self):
self.dtype = np.float32


class TestMaskedSelectOp1(TestMaskedSelectOp):

def init(self):
self.shape = (6, 8, 9, 18)


class TestMaskedSelectOp2(TestMaskedSelectOp):

def init(self):
self.shape = (168,)


self.shape = (168, )

class TestMaskedSelectOpFp16(TestMaskedSelectOp):

def init_dtype(self):
self.dtype = np.float16

def test_check_grad(self):
x_grad = self.inputs['Mask'].astype(self.dtype)
x_grad = x_grad * (1 / x_grad.sum())
self.check_grad_with_place(self.place, ['X'],
'Y',
user_defined_grads=[x_grad])


@skip_check_grad_ci(reason="get_numeric_gradient not support int32")
class TestMaskedSelectOpInt32(TestMaskedSelectOp):

def init_dtype(self):
self.dtype = np.int32

def test_check_grad(self):
pass


class TestMaskedSelectOpFp16(TestMaskedSelectOp):
def init_dtype(self):
self.dtype = np.float16

def test_check_grad(self):
x_grad = self.inputs['Mask'].astype(self.dtype)
x_grad = x_grad * (1 / x_grad.size)
self.check_grad_with_place(
self.place, ['X'], 'Y', user_defined_grads=[x_grad]
)


class TestMaskedSelectAPI(unittest.TestCase):

def test_imperative_mode(self):
paddle.disable_static()
shape = (88, 6, 8)
@@ -101,60 +111,61 @@ def test_imperative_mode(self):
x = paddle.to_tensor(np_x)
mask = paddle.to_tensor(np_mask)
out = paddle.masked_select(x, mask)
np_out = np_masked_select(shape, np_x, np_mask)
np_out = np_masked_select(np_x, np_mask)
self.assertEqual(np.allclose(out.numpy(), np_out), True)
paddle.enable_static()

def test_static_mode(self):
shape = [8, 9, 6]
x = paddle.fluid.data(shape=shape, dtype='float32', name='x')
mask = paddle.fluid.data(shape=shape, dtype='bool', name='mask')
np_x = np.random.random(shape).astype('float32')
np_mask = np.array(np.random.randint(2, size=shape, dtype=bool))

out = paddle.masked_select(x, mask)
np_out = np_masked_select(shape, np_x, np_mask)

np_out = np_masked_select(np_x, np_mask)
exe = paddle.static.Executor(place=paddle.device.MLUPlace(0))

res = exe.run(
paddle.static.default_main_program(),
feed={"x": np_x, "mask": np_mask},
fetch_list=[out],
)

res = exe.run(paddle.static.default_main_program(),
feed={
"x": np_x,
"mask": np_mask
},
fetch_list=[out])
self.assertEqual(np.allclose(res, np_out), True)


class TestMaskedSelectError(unittest.TestCase):

def test_error(self):
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):

with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):

shape = [8, 9, 6]
x = paddle.fluid.data(shape=shape, dtype='float32', name='x')
mask = paddle.fluid.data(shape=shape, dtype='bool', name='mask')
mask_float = paddle.fluid.data(
shape=shape, dtype='float32', name='mask_float'
)
mask_float = paddle.fluid.data(shape=shape,
dtype='float32',
name='mask_float')
np_x = np.random.random(shape).astype('float32')
np_mask = np.array(np.random.randint(2, size=shape, dtype=bool))

def test_x_type():
paddle.masked_select(np_x, mask)

self.assertRaises(TypeError, test_x_type)

def test_mask_type():
paddle.masked_select(x, np_mask)

self.assertRaises(TypeError, test_mask_type)

def test_mask_dtype():
paddle.masked_select(x, mask_float)

self.assertRaises(TypeError, test_mask_dtype)


if __name__ == '__main__':
unittest.main()
