Skip to content

Commit

Permalink
op:transpose_op supports bool type (PaddlePaddle#35886)
Browse files Browse the repository at this point in the history
* Pass compat of conv_transpose_bias_mkldnn_fuse_pass

* Fix a bug of strided_slice op, about the axes parameter access memory out of bounds

* Fix a bug of transpose op, about accessing memory out of bounds of the perm param

* op:transpose_op supports bool type
  • Loading branch information
TeslaZhao authored and AnnaTrainingG committed Sep 29, 2021
1 parent 01e71ff commit cc92567
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 7 deletions.
1 change: 1 addition & 0 deletions paddle/fluid/operators/math/math_function.cu
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ template struct SetConstant<platform::CUDADeviceContext,
platform::complex<double>>;

#define DEFINE_GPU_TRANS(RANK) \
template struct Transpose<platform::CUDADeviceContext, bool, RANK>; \
template struct Transpose<platform::CUDADeviceContext, float, RANK>; \
template struct Transpose<platform::CUDADeviceContext, double, RANK>; \
template struct Transpose<platform::CUDADeviceContext, float16, RANK>; \
Expand Down
8 changes: 6 additions & 2 deletions paddle/fluid/operators/transpose_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -350,14 +350,16 @@ REGISTER_OPERATOR(
REGISTER_OPERATOR(transpose_grad, ops::TransposeOpGrad);

REGISTER_OP_CPU_KERNEL(
transpose, ops::TransposeKernel<paddle::platform::CPUDeviceContext, float>,
transpose, ops::TransposeKernel<paddle::platform::CPUDeviceContext, bool>,
ops::TransposeKernel<paddle::platform::CPUDeviceContext, float>,
ops::TransposeKernel<paddle::platform::CPUDeviceContext, double>,
ops::TransposeKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::TransposeKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL(
transpose_grad,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, bool>,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext,
Expand All @@ -373,7 +375,8 @@ REGISTER_OPERATOR(transpose2_grad, ops::Transpose2OpGrad,
ops::Transpose2DoubleGradMaker<paddle::imperative::OpBase>);

REGISTER_OP_CPU_KERNEL(
transpose2, ops::TransposeKernel<paddle::platform::CPUDeviceContext, float>,
transpose2, ops::TransposeKernel<paddle::platform::CPUDeviceContext, bool>,
ops::TransposeKernel<paddle::platform::CPUDeviceContext, float>,
ops::TransposeKernel<paddle::platform::CPUDeviceContext, int32_t>,
ops::TransposeKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::TransposeKernel<paddle::platform::CPUDeviceContext, double>,
Expand All @@ -383,6 +386,7 @@ REGISTER_OP_CPU_KERNEL(
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL(
transpose2_grad,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, bool>,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, int32_t>,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, float>,
Expand Down
4 changes: 4 additions & 0 deletions paddle/fluid/operators/transpose_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ namespace plat = paddle::platform;

REGISTER_OP_CUDA_KERNEL(
transpose,
ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext, bool>,
ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext, float>,
ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext, double>,
ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext, plat::float16>,
Expand All @@ -92,6 +93,7 @@ REGISTER_OP_CUDA_KERNEL(
paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
transpose_grad,
ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext, bool>,
ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext, float>,
ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext, double>,
ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext,
Expand All @@ -103,6 +105,7 @@ REGISTER_OP_CUDA_KERNEL(

REGISTER_OP_CUDA_KERNEL(
transpose2,
ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext, bool>,
ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext, int32_t>,
ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext, float>,
Expand All @@ -114,6 +117,7 @@ REGISTER_OP_CUDA_KERNEL(
paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
transpose2_grad,
ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext, bool>,
ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext, int32_t>,
ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext, float>,
Expand Down
6 changes: 3 additions & 3 deletions python/paddle/fluid/layers/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5499,12 +5499,12 @@ def transpose(x, perm, name=None):
perm[i]-th dimension of `input`.

Args:
x (Tensor): The input Tensor. It is a N-D Tensor of data types float32, float64, int32.
x (Tensor): The input Tensor. It is a N-D Tensor of data types bool, float32, float64, int32.
perm (list|tuple): Permute the input according to the data of perm.
name (str): The name of this layer. It is optional.

Returns:
Tensor: A transposed n-D Tensor, with data type being float32, float64, int32, int64.
Tensor: A transposed n-D Tensor, with data type being bool, float32, float64, int32, int64.

For Example:

Expand Down Expand Up @@ -5546,7 +5546,7 @@ def transpose(x, perm, name=None):
return out

check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64'],
x, 'x', ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'],
'transpose')
check_type(perm, 'perm', (list, tuple), 'transpose')
if isinstance(perm, tuple):
Expand Down
97 changes: 95 additions & 2 deletions python/paddle/fluid/tests/unittests/test_transpose_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,99 @@ def initTestCase(self):
self.axis = (6, 1, 3, 5, 0, 2, 4, 7)


class TestTransposeOpBool(TestTransposeOp):
def test_check_grad(self):
pass


class TestTransposeOpBool1D(TestTransposeOpBool):
def initTestCase(self):
self.shape = (100, )
self.axis = (0, )
self.inputs = {'X': np.random.random(self.shape).astype("bool")}
self.outputs = {
'XShape': np.random.random(self.shape).astype("bool"),
'Out': self.inputs['X'].transpose(self.axis)
}


class TestTransposeOpBool2D(TestTransposeOpBool):
def initTestCase(self):
self.shape = (3, 40)
self.axis = (1, 0)
self.inputs = {'X': np.random.random(self.shape).astype("bool")}
self.outputs = {
'XShape': np.random.random(self.shape).astype("bool"),
'Out': self.inputs['X'].transpose(self.axis)
}


class TestTransposeOpBool3D(TestTransposeOpBool):
def initTestCase(self):
self.shape = (3, 4, 10)
self.axis = (0, 2, 1)
self.inputs = {'X': np.random.random(self.shape).astype("bool")}
self.outputs = {
'XShape': np.random.random(self.shape).astype("bool"),
'Out': self.inputs['X'].transpose(self.axis)
}


class TestTransposeOpBool4D(TestTransposeOpBool):
def initTestCase(self):
self.shape = (2, 3, 4, 5)
self.axis = (0, 2, 3, 1)
self.inputs = {'X': np.random.random(self.shape).astype("bool")}
self.outputs = {
'XShape': np.random.random(self.shape).astype("bool"),
'Out': self.inputs['X'].transpose(self.axis)
}


class TestTransposeOpBool5D(TestTransposeOpBool):
def initTestCase(self):
self.shape = (2, 3, 4, 5, 6)
self.axis = (4, 2, 3, 1, 0)
self.inputs = {'X': np.random.random(self.shape).astype("bool")}
self.outputs = {
'XShape': np.random.random(self.shape).astype("bool"),
'Out': self.inputs['X'].transpose(self.axis)
}


class TestTransposeOpBool6D(TestTransposeOpBool):
def initTestCase(self):
self.shape = (2, 3, 4, 5, 6, 1)
self.axis = (4, 2, 3, 1, 0, 5)
self.inputs = {'X': np.random.random(self.shape).astype("bool")}
self.outputs = {
'XShape': np.random.random(self.shape).astype("bool"),
'Out': self.inputs['X'].transpose(self.axis)
}


class TestTransposeOpBool7D(TestTransposeOpBool):
def initTestCase(self):
self.shape = (2, 3, 2, 3, 2, 4, 3)
self.axis = (0, 1, 3, 2, 4, 5, 6)
self.inputs = {'X': np.random.random(self.shape).astype("bool")}
self.outputs = {
'XShape': np.random.random(self.shape).astype("bool"),
'Out': self.inputs['X'].transpose(self.axis)
}


class TestTransposeOpBool8D(TestTransposeOpBool):
def initTestCase(self):
self.shape = (2, 3, 2, 3, 2, 4, 3, 3)
self.axis = (6, 1, 3, 5, 0, 2, 4, 7)
self.inputs = {'X': np.random.random(self.shape).astype("bool")}
self.outputs = {
'XShape': np.random.random(self.shape).astype("bool"),
'Out': self.inputs['X'].transpose(self.axis)
}


class TestTransposeOpError(unittest.TestCase):
def test_errors(self):
paddle.enable_static()
Expand All @@ -126,9 +219,9 @@ def test_x_Variable_check():
self.assertRaises(TypeError, test_x_Variable_check)

def test_x_dtype_check():
# the Input(x)'s dtype must be one of [float16, float32, float64, int32, int64]
# the Input(x)'s dtype must be one of [bool, float16, float32, float64, int32, int64]
x1 = fluid.layers.data(
name='x1', shape=[10, 5, 3], dtype='bool')
name='x1', shape=[10, 5, 3], dtype='int8')
fluid.layers.transpose(x1, perm=[1, 0, 2])

self.assertRaises(TypeError, test_x_dtype_check)
Expand Down

0 comments on commit cc92567

Please sign in to comment.