Add overlap_add, sign tests
co63oc committed Apr 12, 2023
1 parent 57069f8 commit be66ac9
Showing 8 changed files with 112 additions and 13 deletions.
1 change: 1 addition & 0 deletions paddle/phi/kernels/funcs/eigen/sign.cu
@@ -32,6 +32,7 @@ struct EigenSign<Eigen::GpuDevice, T> {
 template struct EigenSign<Eigen::GpuDevice, float>;
 template struct EigenSign<Eigen::GpuDevice, double>;
 template struct EigenSign<Eigen::GpuDevice, dtype::float16>;
+template struct EigenSign<Eigen::GpuDevice, dtype::bfloat16>;

 }  // namespace funcs
 }  // namespace phi
1 change: 1 addition & 0 deletions paddle/phi/kernels/gpu/overlap_add_grad_kernel.cu
@@ -161,5 +161,6 @@ PD_REGISTER_KERNEL(overlap_add_grad,
                    float,
                    double,
                    phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}
1 change: 1 addition & 0 deletions paddle/phi/kernels/gpu/overlap_add_kernel.cu
@@ -147,5 +147,6 @@ PD_REGISTER_KERNEL(overlap_add,
                    float,
                    double,
                    phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}
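
Note: taken together, the two registrations above let overlap_add run forward and backward on bfloat16 GPU tensors. A minimal dygraph sketch of the intended effect — not part of this commit, and assuming a CUDA build where bfloat16 is supported by the surrounding ops as well:

import paddle

# Frames shaped (frame_length=50, n_frames=3), matching the new test case.
x = paddle.randn([50, 3]).astype('bfloat16')
x.stop_gradient = False

# Forward: frames overlap with a hop of 4, so the output signal has
# length 50 + 4 * (3 - 1) = 58 along the last axis.
y = paddle.signal.overlap_add(x, hop_length=4, axis=-1)

# Backward: exercises the overlap_add_grad kernel registered above.
y.sum().backward()
print(x.grad.shape)  # [50, 3]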
16 changes: 10 additions & 6 deletions paddle/phi/kernels/gpu/sign_kernel.cu.cc
@@ -19,9 +19,13 @@ limitations under the License. */
 #include "paddle/phi/kernels/impl/sign_kernel_impl.h"

 // See Note [ Why still include the fluid headers? ]
-#include "paddle/phi/common/float16.h"
-
-using float16 = phi::dtype::float16;
-
-PD_REGISTER_KERNEL(
-    sign, GPU, ALL_LAYOUT, phi::SignKernel, float, double, float16) {}
+#include "paddle/phi/common/amp_type_traits.h"
+
+PD_REGISTER_KERNEL(sign,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::SignKernel,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
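
Note: as with overlap_add, a hedged dygraph sketch of what the new sign registration enables (assumes a CUDA build with bfloat16 support):

import paddle

x = paddle.to_tensor([-2.5, 0.0, 3.0]).astype('bfloat16')
y = paddle.sign(x)  # expected: [-1., 0., 1.], still bfloat16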
57 changes: 55 additions & 2 deletions python/paddle/fluid/tests/unittests/test_overlap_add_op.py
@@ -15,14 +15,15 @@
 import unittest

 import numpy as np
-from eager_op_test import OpTest
+from eager_op_test import OpTest, convert_float_to_uint16

 import paddle
+from paddle.fluid import core


 def overlap_add(x, hop_length, axis=-1):
     assert axis in [0, -1], 'axis should be 0/-1.'
-    assert len(x.shape) >= 2, 'Input dims shoulb be >= 2.'
+    assert len(x.shape) >= 2, 'Input dims should be >= 2.'

     squeeze_output = False
     if len(x.shape) == 2:
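
Note: for orientation, the reference helper being patched above implements standard overlap-add; a naive NumPy sketch of the axis=-1 case (illustrative only — the file's real helper also handles axis=0 and batched inputs):

import numpy as np

def naive_overlap_add(x, hop_length):
    # x: (frame_length, n_frames); frame i is shifted by i * hop_length
    # samples and summed into the output signal.
    frame_length, n_frames = x.shape
    out = np.zeros((n_frames - 1) * hop_length + frame_length, dtype=x.dtype)
    for i in range(n_frames):
        out[i * hop_length : i * hop_length + frame_length] += x[:, i]
    return out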
@@ -101,6 +102,58 @@ def test_check_grad_normal(self):
         paddle.disable_static()


+class TestOverlapAddFP16Op(TestOverlapAddOp):
+    def initTestCase(self):
+        input_shape = (50, 3)
+        input_type = 'float16'
+        attrs = {
+            'hop_length': 4,
+            'axis': -1,
+        }
+        return input_shape, input_type, attrs
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or not support bfloat16",
+)
+class TestOverlapAddBF16Op(OpTest):
+    def setUp(self):
+        self.op_type = "overlap_add"
+        self.python_api = paddle.signal.overlap_add
+        self.shape, self.type, self.attrs = self.initTestCase()
+        self.np_dtype = np.float32
+        self.dtype = np.uint16
+        self.inputs = {
+            'X': np.random.random(size=self.shape).astype(self.np_dtype),
+        }
+        self.outputs = {'Out': overlap_add(x=self.inputs['X'], **self.attrs)}
+
+        self.inputs['X'] = convert_float_to_uint16(self.inputs['X'])
+        self.outputs['Out'] = convert_float_to_uint16(self.outputs['Out'])
+        self.place = core.CUDAPlace(0)
+
+    def initTestCase(self):
+        input_shape = (50, 3)
+        input_type = np.uint16
+        attrs = {
+            'hop_length': 4,
+            'axis': -1,
+        }
+        return input_shape, input_type, attrs
+
+    def test_check_output(self):
+        paddle.enable_static()
+        self.check_output_with_place(self.place)
+        paddle.disable_static()
+
+    def test_check_grad_normal(self):
+        paddle.enable_static()
+        self.check_grad_with_place(self.place, ['X'], 'Out')
+        paddle.disable_static()
+
+
 class TestCase1(TestOverlapAddOp):
     def initTestCase(self):
         input_shape = (3, 50)
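
Note: the BF16 test stores data as np.uint16 because NumPy has no native bfloat16 type; convert_float_to_uint16 packs each float32 value into its 16-bit bfloat16 pattern. A rough sketch of that conversion (plain truncation — the actual helper may also round):

import numpy as np

def float32_to_bfloat16_bits(x):
    # bfloat16 keeps float32's sign bit and 8-bit exponent and drops the
    # low 16 mantissa bits, so the bit pattern is the high half-word.
    x = np.ascontiguousarray(x, dtype=np.float32)
    return (x.view(np.uint32) >> 16).astype(np.uint16)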
42 changes: 39 additions & 3 deletions python/paddle/fluid/tests/unittests/test_sign_op.py
@@ -17,7 +17,7 @@
 import gradient_checker
 import numpy as np
 from decorator_helper import prog_scope
-from eager_op_test import OpTest
+from eager_op_test import OpTest, convert_float_to_uint16

 import paddle
 from paddle import fluid
@@ -40,6 +40,42 @@ def test_check_grad(self):
         self.check_grad(['X'], 'Out')


+class TestSignFP16Op(TestSignOp):
+    def setUp(self):
+        self.op_type = "sign"
+        self.python_api = paddle.sign
+        self.inputs = {
+            'X': np.random.uniform(-10, 10, (10, 10)).astype("float16")
+        }
+        self.outputs = {'Out': np.sign(self.inputs['X'])}
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or not support bfloat16",
+)
+class TestSignBF16Op(OpTest):
+    def setUp(self):
+        self.op_type = "sign"
+        self.python_api = paddle.sign
+        self.dtype = np.uint16
+        self.inputs = {
+            'X': np.random.uniform(-10, 10, (10, 10)).astype("float32")
+        }
+        self.outputs = {'Out': np.sign(self.inputs['X'])}
+
+        self.inputs['X'] = convert_float_to_uint16(self.inputs['X'])
+        self.outputs['Out'] = convert_float_to_uint16(self.outputs['Out'])
+        self.place = core.CUDAPlace(0)
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place)
+
+    def test_check_grad(self):
+        self.check_grad_with_place(self.place, ['X'], 'Out')
+
+
 class TestSignOpError(unittest.TestCase):
     def test_errors(self):
         with program_guard(Program(), Program()):
@@ -97,7 +133,7 @@ def sign_wrapper(self, x):

     @prog_scope()
     def func(self, place):
-        # the shape of input variable should be clearly specified, not inlcude -1.
+        # the shape of input variable should be clearly specified, not include -1.
         eps = 0.005
         dtype = np.float32

@@ -128,7 +164,7 @@ def sign_wrapper(self, x):

     @prog_scope()
     def func(self, place):
-        # the shape of input variable should be clearly specified, not inlcude -1.
+        # the shape of input variable should be clearly specified, not include -1.
         eps = 0.005
         dtype = np.float32

5 changes: 4 additions & 1 deletion python/paddle/signal.py
@@ -219,7 +219,10 @@ def overlap_add(x, hop_length, axis=-1, name=None):
         out = op(x, *attrs)
     else:
         check_variable_and_dtype(
-            x, 'x', ['int32', 'int64', 'float16', 'float32', 'float64'], op_type
+            x,
+            'x',
+            ['int32', 'int64', 'float16', 'float32', 'float64', 'uint16'],
+            op_type,
         )
         helper = LayerHelper(op_type, **locals())
         dtype = helper.input_dtype(input_param_name='x')
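
Note: the whitelist gains 'uint16' rather than 'bfloat16' because static-graph Variables carry bfloat16 values under the uint16 element type — the same convention the uint16-based tests above rely on. A hedged static-mode sketch (the uint16 dtype string for bfloat16 data is an assumption about this build):

import paddle

paddle.enable_static()
main = paddle.static.Program()
with paddle.static.program_guard(main):
    # Declared via its uint16 carrier type, a bfloat16 input now passes
    # check_variable_and_dtype instead of raising a TypeError.
    x = paddle.static.data('x', shape=[50, 3], dtype='uint16')
    y = paddle.signal.overlap_add(x, hop_length=4)
paddle.disable_static()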
2 changes: 1 addition & 1 deletion python/paddle/tensor/math.py
@@ -3650,7 +3650,7 @@ def sign(x, name=None):
         return _C_ops.sign(x)
     else:
         check_variable_and_dtype(
-            x, 'x', ['float16', 'float32', 'float64'], 'sign'
+            x, 'x', ['float16', 'float32', 'float64', 'uint16'], 'sign'
         )
         helper = LayerHelper("sign", **locals())
         out = helper.create_variable_for_type_inference(dtype=x.dtype)
