add complex support for channel_shuffle, channel_shuffle_grad, shuffle_batch and shuffle_batch_grad
zyt1024 committed Jan 28, 2024
1 parent d291727 commit 3e6674a
Showing 12 changed files with 162 additions and 33 deletions.
4 changes: 3 additions & 1 deletion paddle/phi/kernels/cpu/channel_shuffle_grad_kernel.cc
@@ -23,4 +23,6 @@ PD_REGISTER_KERNEL(channel_shuffle_grad,
ALL_LAYOUT,
phi::ChannelShuffleGradKernel,
float,
double) {}
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
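
For context, channel shuffle is a pure reshape/transpose over the channel axis, so supporting complex numbers needs no new numeric code, only the dtype registrations shown above. A minimal NumPy sketch of the assumed NCHW semantics (mirroring what the channel_shuffle_np helper in the tests below is expected to compute):

import numpy as np

def channel_shuffle_ref(x, groups):
    # Split C into (groups, C // groups), swap the two axes, flatten back.
    n, c, h, w = x.shape
    y = x.reshape(n, groups, c // groups, h, w)
    y = y.transpose(0, 2, 1, 3, 4)
    return y.reshape(n, c, h, w)

# The same element moves work unchanged for complex inputs:
x = (np.random.random((2, 6, 4, 4))
     + 1j * np.random.random((2, 6, 4, 4))).astype(np.complex64)
y = channel_shuffle_ref(x, groups=3)
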
4 changes: 3 additions & 1 deletion paddle/phi/kernels/cpu/channel_shuffle_kernel.cc
@@ -23,4 +23,6 @@ PD_REGISTER_KERNEL(channel_shuffle,
ALL_LAYOUT,
phi::ChannelShuffleKernel,
float,
double) {}
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
4 changes: 3 additions & 1 deletion paddle/phi/kernels/cpu/shuffle_batch_grad_kernel.cc
@@ -57,4 +57,6 @@ PD_REGISTER_KERNEL(shuffle_batch_grad,
float,
double,
int32_t,
int64_t) {}
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
4 changes: 3 additions & 1 deletion paddle/phi/kernels/cpu/shuffle_batch_kernel.cc
@@ -112,7 +112,9 @@ PD_REGISTER_KERNEL(shuffle_batch,
float,
double,
int32_t,
int64_t) {
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {
kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
kernel->OutputAt(2).SetDataType(phi::DataType::INT64);
}
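
Note the two SetDataType calls: outputs 1 and 2 of shuffle_batch (presumably the shuffle index and the forwarded seed) stay INT64 whatever the data dtype is, so only the data path needed the complex registrations. A NumPy sketch of the assumed forward and backward semantics:

import numpy as np

rng = np.random.default_rng(0)

def shuffle_batch_ref(x):
    # Permute the flattened leading dims; keep the permutation as int64
    # so the backward pass can route gradients, whatever dtype x has.
    rows = x.reshape(-1, x.shape[-1])
    idx = rng.permutation(rows.shape[0]).astype(np.int64)
    return rows[idx].reshape(x.shape), idx

def shuffle_batch_grad_ref(dout, idx):
    # Scatter the upstream gradient back to the pre-shuffle row order.
    rows = dout.reshape(-1, dout.shape[-1])
    dx = np.empty_like(rows)
    dx[idx] = rows
    return dx.reshape(dout.shape)
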
4 changes: 3 additions & 1 deletion paddle/phi/kernels/gpu/channel_shuffle_grad_kernel.cu
@@ -25,4 +25,6 @@ PD_REGISTER_KERNEL(channel_shuffle_grad,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
4 changes: 3 additions & 1 deletion paddle/phi/kernels/gpu/channel_shuffle_kernel.cu
@@ -25,4 +25,6 @@ PD_REGISTER_KERNEL(channel_shuffle,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
4 changes: 3 additions & 1 deletion paddle/phi/kernels/gpu/shuffle_batch_grad_kernel.cu
@@ -63,5 +63,7 @@ PD_REGISTER_KERNEL(shuffle_batch_grad,
float,
double,
int32_t,
int64_t) {}
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
#endif
4 changes: 3 additions & 1 deletion paddle/phi/kernels/gpu/shuffle_batch_kernel.cu
@@ -109,7 +109,9 @@ PD_REGISTER_KERNEL(shuffle_batch,
float,
double,
int32_t,
int64_t) {
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {
kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
kernel->OutputAt(2).SetDataType(phi::DataType::INT64);
}
2 changes: 1 addition & 1 deletion python/paddle/incubate/layers/nn.py
@@ -467,7 +467,7 @@ def shuffle_batch(x, seed=None):
Out.dims = [4, 2]
Args:
x (Tensor): The input Tensor. The input Tensor is a N-D LoDTensor with type int, float32 or float64.
x (Tensor): The input Tensor. The input Tensor is an N-D LoDTensor with type int, float32, float64, complex64 or complex128.
seed (None|int|Tensor, optional): The start up seed. If set, seed will be set as the start up seed of shuffle engine.
If not set (Default), start up seed of shuffle engine will be generated randomly. Default: None.
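
With the dtype check widened, the incubate API accepts complex inputs. A static-graph usage sketch (the import path and executor setup are assumptions based on this file's location, not part of the change):

import numpy as np
import paddle
from paddle.incubate.layers import shuffle_batch

paddle.enable_static()
main = paddle.static.Program()
with paddle.static.program_guard(main):
    x = paddle.static.data(name='x', shape=[4, 2], dtype='complex64')
    out = shuffle_batch(x, seed=1)

exe = paddle.static.Executor()
x_np = (np.random.random((4, 2))
        + 1j * np.random.random((4, 2))).astype('complex64')
res, = exe.run(main, feed={'x': x_np}, fetch_list=[out])
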
9 changes: 7 additions & 2 deletions python/paddle/nn/functional/vision.py
@@ -469,7 +469,7 @@ def channel_shuffle(x, groups, data_format="NCHW", name=None):
See more details in :ref:`api_paddle_nn_ChannelShuffle`.
Parameters:
x (Tensor): 4-D tensor, the data type should be float32 or float64.
x (Tensor): 4-D tensor, the data type should be float32, float64, complex64 or complex128.
groups (int): Number of groups to divide channels in.
data_format (str, optional): The data format of the input and output data. An optional string of NCHW or NHWC. The default is NCHW. When it is NCHW, the data is stored in the order of [batch_size, input_channels, input_height, input_width].
name (str, optional): Name for the operation (optional, default is None). Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`.
@@ -523,7 +523,12 @@ def channel_shuffle(x, groups, data_format="NCHW", name=None):
return _C_ops.channel_shuffle(x, groups, data_format)

helper = LayerHelper("channel_shuffle", **locals())
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'channel_shuffle')
check_variable_and_dtype(
x,
'x',
['float32', 'float64', 'complex64', 'complex128'],
'channel_shuffle',
)
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type="channel_shuffle",
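
Correspondingly, F.channel_shuffle now passes complex tensors through the static type check, and in dygraph mode the call is direct. A minimal sketch matching the pattern the tests below use:

import numpy as np
import paddle
import paddle.nn.functional as F

x_np = (np.random.random((2, 9, 4, 4))
        + 1j * np.random.random((2, 9, 4, 4))).astype('complex64')
out = F.channel_shuffle(paddle.to_tensor(x_np), 3)  # output keeps complex64
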
117 changes: 97 additions & 20 deletions test/legacy_test/test_channel_shuffle.py
@@ -58,6 +58,11 @@ def setUp(self):
groups = 3

x = np.random.random(shape).astype(self.dtype)
if self.dtype == 'complex64' or self.dtype == 'complex128':
x = (np.random.random(shape) + 1j * np.random.random(shape)).astype(
self.dtype
)

npresult = channel_shuffle_np(x, groups, self.format)

self.inputs = {'X': x}
@@ -82,13 +87,47 @@ def init_data_format(self):
self.format = "NHWC"


class TestChannelShuffleOp_complex64(TestChannelShuffleOp):
def init_dtype(self):
self.dtype = 'complex64'


class TestChannelLast_complex64(TestChannelLast):
def init_dtype(self):
self.dtype = 'complex64'


class TestChannelShuffleOp_complex128(TestChannelShuffleOp):
def init_dtype(self):
self.dtype = 'complex128'


class TestChannelLast_complex128(TestChannelLast):
def init_dtype(self):
self.dtype = 'complex128'


class TestChannelShuffleAPI(unittest.TestCase):
def setUp(self):
self.x_2_np = np.random.random([2, 4, 4, 9]).astype("float64")
self.init_dtype()
self.x_2_np = np.random.random([2, 4, 4, 9]).astype(self.dtype)
self.x_1_np = np.random.random([2, 9, 4, 4]).astype(self.dtype)
if self.dtype == 'complex64' or self.dtype == 'complex128':
self.x_2_np = (
np.random.random([2, 4, 4, 9])
+ 1j * np.random.random([2, 4, 4, 9])
).astype(self.dtype)
self.x_1_np = (
np.random.random([2, 9, 4, 4])
+ 1j * np.random.random([2, 9, 4, 4])
).astype(self.dtype)

self.out_2_np = channel_shuffle_np(self.x_2_np, 3, "NHWC")
self.x_1_np = np.random.random([2, 9, 4, 4]).astype("float64")
self.out_1_np = channel_shuffle_np(self.x_1_np, 3)

def init_dtype(self):
self.dtype = 'float64'

@test_with_pir_api
def test_static_graph_functional(self):
with paddle.static.program_guard(
@@ -101,7 +140,7 @@

paddle.enable_static()
x_1 = paddle.static.data(
name="x", shape=[2, 9, 4, 4], dtype="float64"
name="x", shape=[2, 9, 4, 4], dtype=self.dtype
)
out_1 = F.channel_shuffle(x_1, 3)

@@ -128,7 +167,7 @@ def test_static_graph_layer(self):

paddle.enable_static()
x_1 = paddle.static.data(
name="x", shape=[2, 9, 4, 4], dtype="float64"
name="x", shape=[2, 9, 4, 4], dtype=self.dtype
)
# init instance
ps_1 = paddle.nn.ChannelShuffle(3)
@@ -157,7 +196,7 @@ def test_static_graph_functional_new(self):

paddle.enable_static()
x_2 = paddle.static.data(
name="x2", shape=[2, 4, 4, 9], dtype="float64"
name="x2", shape=[2, 4, 4, 9], dtype=self.dtype
)
out_2 = F.channel_shuffle(x_2, 3, "NHWC")

@@ -183,7 +222,7 @@ def test_static_graph_layer_new(self):

paddle.enable_static()
x_2 = paddle.static.data(
name="x2", shape=[2, 4, 4, 9], dtype="float64"
name="x2", shape=[2, 4, 4, 9], dtype=self.dtype
)
# init instance
ps_2 = paddle.nn.ChannelShuffle(3, "NHWC")
@@ -209,7 +248,11 @@ def run_dygraph(self, groups, data_format):
if data_format == "NHWC":
shape = [n, h, w, c]

x = np.random.random(shape).astype("float64")
x = np.random.random(shape).astype(self.dtype)
if self.dtype == 'complex64' or self.dtype == 'complex128':
x = (np.random.random(shape) + 1j * np.random.random(shape)).astype(
self.dtype
)

npresult = channel_shuffle_np(x, groups, data_format)

@@ -246,35 +289,63 @@ def test_dygraph2(self):
self.run_dygraph(3, "NHWC")


class TestChannelShuffleAPI_complex64(TestChannelShuffleAPI):
def init_dtype(self):
self.dtype = 'complex64'


class TestChannelShuffleAPI_complex128(TestChannelShuffleAPI):
def init_dtype(self):
self.dtype = 'complex128'


class TestChannelShuffleError(unittest.TestCase):
def setUp(self):
self.init_dtype()
self.x = np.random.random([2, 9, 4, 4]).astype(self.dtype)
self.x_other = np.random.random([9, 4, 4]).astype(self.dtype)
if self.dtype == 'complex64' or self.dtype == 'complex128':
self.x = (
np.random.random([2, 9, 4, 4])
+ 1j * np.random.random([2, 9, 4, 4])
).astype(self.dtype)
self.x_other = (
np.random.random([9, 4, 4]) + 1j * np.random.random([9, 4, 4])
).astype(self.dtype)

def init_dtype(self):
self.dtype = 'float64'

@test_with_pir_api
def test_error_functional(self):
def error_input():
with paddle.base.dygraph.guard():
x = np.random.random([9, 4, 4]).astype("float64")
channel_shuffle = F.channel_shuffle(paddle.to_tensor(x), 3)
channel_shuffle = F.channel_shuffle(
paddle.to_tensor(self.x_other), 3
)

self.assertRaises(ValueError, error_input)

def error_groups_1():
with paddle.base.dygraph.guard():
x = np.random.random([2, 9, 4, 4]).astype("float64")
channel_shuffle = F.channel_shuffle(paddle.to_tensor(x), 3.33)
channel_shuffle = F.channel_shuffle(
paddle.to_tensor(self.x), 3.33
)

self.assertRaises(TypeError, error_groups_1)

def error_groups_2():
with paddle.base.dygraph.guard():
x = np.random.random([2, 9, 4, 4]).astype("float64")
channel_shuffle = F.channel_shuffle(paddle.to_tensor(x), -1)
channel_shuffle = F.channel_shuffle(
paddle.to_tensor(self.x), -1
)

self.assertRaises(ValueError, error_groups_2)

def error_data_format():
with paddle.base.dygraph.guard():
x = np.random.random([2, 9, 4, 4]).astype("float64")
channel_shuffle = F.channel_shuffle(
paddle.to_tensor(x), 3, "WOW"
paddle.to_tensor(self.x), 3, "WOW"
)

self.assertRaises(ValueError, error_data_format)
Expand All @@ -283,34 +354,40 @@ def error_data_format():
def test_error_layer(self):
def error_input_layer():
with paddle.base.dygraph.guard():
x = np.random.random([9, 4, 4]).astype("float64")
cs = paddle.nn.ChannelShuffle(3)
cs(paddle.to_tensor(x))
cs(paddle.to_tensor(self.x_other))

self.assertRaises(ValueError, error_input_layer)

def error_groups_layer_1():
with paddle.base.dygraph.guard():
x = np.random.random([2, 9, 4, 4]).astype("float64")
cs = paddle.nn.ChannelShuffle(3.33)

self.assertRaises(TypeError, error_groups_layer_1)

def error_groups_layer_2():
with paddle.base.dygraph.guard():
x = np.random.random([2, 9, 4, 4]).astype("float64")
cs = paddle.nn.ChannelShuffle(-1)

self.assertRaises(ValueError, error_groups_layer_2)

def error_data_format_layer():
with paddle.base.dygraph.guard():
x = np.random.random([2, 9, 4, 4]).astype("float64")
cs = paddle.nn.ChannelShuffle(3, "MEOW")

self.assertRaises(ValueError, error_data_format_layer)


class TestChannelShuffleError_complex64(TestChannelShuffleError):
def init_dtype(self):
self.dtype = 'complex64'


class TestChannelShuffleError_complex128(TestChannelShuffleError):
def init_dtype(self):
self.dtype = 'complex128'


class TestChannelShuffleFP16OP(TestChannelShuffleOp):
def init_dtype(self):
self.dtype = np.float16
35 changes: 33 additions & 2 deletions test/legacy_test/test_shuffle_batch_op.py
@@ -25,6 +25,11 @@
class TestShuffleBatchOpBase(OpTest):
def gen_random_array(self, shape, low=0, high=1):
rnd = (high - low) * np.random.random(shape) + low
if self.dtype == np.complex64 or self.dtype == np.complex128:
rnd = ((high - low) * np.random.random(shape) + low) + 1j * (
(high - low) * np.random.random(shape) + low
)
return rnd.astype(self.dtype)

def get_shape(self):
@@ -38,7 +43,7 @@ def _get_places(self):

def setUp(self):
self.op_type = 'shuffle_batch'
self.dtype = np.float64
self.dtype = self.get_dtype()
self.shape = self.get_shape()
x = self.gen_random_array(self.shape)
seed = np.random.random_integers(low=10, high=100, size=(1,)).astype(
@@ -72,9 +77,15 @@ def sort_array(self, array):
shape = array.shape
new_shape = [-1, shape[-1]]
arr_list = np.reshape(array, new_shape).tolist()
arr_list.sort(key=lambda x: x[0])
if array.dtype == np.complex64 or array.dtype == np.complex128:
arr_list.sort(key=lambda x: (x[0].real, x[0].imag))
else:
arr_list.sort(key=lambda x: x[0])
return np.reshape(np.array(arr_list), shape)

def get_dtype(self):
return np.float64

def test_check_grad(self):
self.check_grad(['X'], 'Out', check_dygraph=False)
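
NumPy defines no total order for complex values, which is why sort_array above switches to a lexicographic (real, imag) key for complex arrays. A standalone illustration:

import numpy as np

vals = np.array([2 + 1j, 1 + 5j, 1 + 2j], dtype=np.complex64).tolist()
vals.sort(key=lambda v: (v.real, v.imag))
# -> [(1+2j), (1+5j), (2+1j)]: real part first, imaginary part breaks ties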

@@ -84,5 +95,25 @@ def get_shape(self):
return (4, 30)


class TestShuffleBatchOpBase_complex64(TestShuffleBatchOpBase):
def get_dtype(self):
return np.complex64

def test_check_grad(self):
self.check_grad(
['X'], 'Out', max_relative_error=0.006, check_dygraph=False
)


class TestShuffleBatchOpBase_complex128(TestShuffleBatchOpBase):
def get_dtype(self):
return np.complex128

def test_check_grad(self):
self.check_grad(
['X'], 'Out', max_relative_error=0.006, check_dygraph=False
)


if __name__ == '__main__':
unittest.main()
