Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon No.56&38】deformable_conv_v1 算子实现 float16 数据类型支持&前向运行加速 #46111

Merged
merged 24 commits into from
Oct 10, 2022

Conversation

Rayman96
Copy link
Contributor

@Rayman96 Rayman96 commented Sep 16, 2022

PR types

New features

PR changes

OPs

Describe

deformable_conv_v1 算子实现 float16 数据类型支持。
通过benchmark中测试用例,float32与float16前向速度接近更快:

Case No. x_shape offset_shape weight_shape mask_shape data_type Paddle Perf(ms)
1 [6, 512, 19, 34] [6, 18, 19, 34] [256, 512, 3, 3] [6, 9, 19, 34] float32 3.858
2 [6, 512, 19, 34] [6, 18, 19, 34] [256, 512, 3, 3] [6, 9, 19, 34] float16 3.387

后向速度有差距基本一致

Case No. x_shape offset_shape weight_shape mask_shape data_type Paddle Perf(ms)
1 [6, 512, 19, 34] [6, 18, 19, 34] [256, 512, 3, 3] [6, 9, 19, 34] float32 27.9
2 [6, 512, 19, 34] [6, 18, 19, 34] [256, 512, 3, 3] [6, 9, 19, 34] float16 28.0

速度差值主要存在于对dx的导数求解上,其余参数导数求解速度一致。dx求解中使用了CudaAtomicAdd,对于float16的支持较差

38题的速度测试如下:
image

@paddle-bot
Copy link

paddle-bot bot commented Sep 16, 2022

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@Rayman96

This comment was marked as resolved.

@Rayman96
Copy link
Contributor Author

向大佬请教一个问题,目前实现了float16的支持,但是对于float16的单测如果我单独执行是可以通过的,但如果一个单测文件中既有float32又有float16,就会报错,后执行的数据类型会出错不清楚为什么😂

@@ -43,6 +43,7 @@ inline void ModulatedDeformableCol2imCPUKernel(
const int height_col,
const int width_col,
T* grad_im) {
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不要求支持CPU的FP16 Kernel,因此这个文件暂时不要修改。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

有的函数由于cc和cu文件共用,就在cpu的文件中增加了少量代码以适配。

@zhangting2020
Copy link
Contributor

Details

贴一下报错日志是什么

@zhangting2020
Copy link
Contributor

这道2🌟题目是要求要完成FP16的性能优化的,基本的要求是FP16性能优于FP32。

@Rayman96
Copy link
Contributor Author

就是提示梯度误差超过阈值,如果单独测fp16或者单独测fp32就不会

@Rayman96
Copy link
Contributor Author

这道2🌟题目是要求要完成FP16的性能优化的,基本的要求是FP16性能优于FP32。

现在有三个反向kernel,一个正向kernel。其中正向的和两个反向的都是和fp32速度基本一致,一个使用了CudaAtomicAdd的kernel速度较慢。也想求助下这个有没有合适的替代方法

@Rayman96
Copy link
Contributor Author

这道2🌟题目是要求要完成FP16的性能优化的,基本的要求是FP16性能优于FP32。

之前因为上述单测的问题,我一直以为fp16计算精度没有达到要求,就将绝大部分转换成了fp32计算。我再优化一下应该可以将目前一致的速度提高一些

@Rayman96 Rayman96 changed the title 【Hackathon No.56】deformable_conv_v1 算子实现 float16 数据类型支持 【Hackathon No.56&38】deformable_conv_v1 算子实现 float16 数据类型支持&前向运行加速 Sep 20, 2022
@Rayman96 Rayman96 changed the title 【Hackathon No.56&38】deformable_conv_v1 算子实现 float16 数据类型支持&前向运行加速 【Hackathon No.56】deformable_conv_v1 算子实现 float16 数据类型支持 Sep 20, 2022
@Rayman96 Rayman96 changed the title 【Hackathon No.56】deformable_conv_v1 算子实现 float16 数据类型支持 【Hackathon No.56&38】deformable_conv_v1 算子实现 float16 数据类型支持&前向运行加速 Sep 20, 2022
@Rayman96
Copy link
Contributor Author

这道2🌟题目是要求要完成FP16的性能优化的,基本的要求是FP16性能优于FP32。

@zhangting2020 已优化完成,float16没有比float32慢的kernel了,前向速度更优,后向速度一致。具体数值更新在了最开始的表格里

@Rayman96
Copy link
Contributor Author

Rayman96 commented Sep 21, 2022

image

benchmark里结果是这样的,paddle实现的代码应该是准确的,但是benchmark用来对照的好像是有误的。

image

CI流程中也都给出的是0.02441这个值

@@ -20,6 +20,8 @@
from op_test import OpTest
from paddle.fluid.framework import _test_eager_guard

paddle.enable_static()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

加这里的目的是?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

加这里的目的是?

这里不加的话单测文件会执行报错
image

@@ -43,6 +43,7 @@ inline void ModulatedDeformableCol2imCPUKernel(
const int height_col,
const int width_col,
T* grad_im) {
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

猜测有没有一种可能,单测出错的原因是因为你这里使用了MT?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

猜测有没有一种可能,单测出错的原因是因为你这里使用了MT?

好像有可能🤔 我在cc文件里不用MT试试

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

猜测有没有一种可能,单测出错的原因是因为你这里使用了MT?

仔细看了下单测的过程,应该不是由于kernel中代码的问题。是因为op_test内check_grad使用的是(y_neg-y_pos) / delta /2的数值方式,而y_neg和y_pos本身是float16的小数位不够,精度不够高,导致分子无法表示真实差值。感觉在float16类型是可能用数值定义法来测试好像不太行😂。
目前我是通过人工核验float16和float32情况下的grad一致性,因为float32是可以利用数值法核验的。
image

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

目前我是通过人工核验float16和float32情况下的grad一致性,因为float32是可以利用数值法核验的。

「人工校验」的方式,是说和python实现的reference版本结果进行对比?确保如下2个方面:
(1)np数据需要显式定义数据类型,默认是double
(2)reference实现也务必使用float16作为输入输出、但是使用float类型进行计算,保证reference实现和cuda kernel中实现的方式是一致。

另外,conv类算子可能本身误差较大,我看fp32的单测,max_relative_error已经设置到0.05、0.1这么大了。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

目前我是通过人工核验float16和float32情况下的grad一致性,因为float32是可以利用数值法核验的。

「人工校验」的方式,是说和python实现的reference版本结果进行对比?确保如下2个方面: (1)np数据需要显式定义数据类型,默认是double (2)reference实现也务必使用float16作为输入输出、但是使用float类型进行计算,保证reference实现和cuda kernel中实现的方式是一致。

另外,conv类算子可能本身误差较大,我看fp32的单测,max_relative_error已经设置到0.05、0.1这么大了。

不是的,是直接使用c++版本的float32和float16二者的grad打印出来进行比较,可以参考这里的截图。https://github.com/PaddlePaddle/Paddle/pull/46111#issuecomment-1253283724。
二者一致应该可以在人工上确认我fp16的实现上应该是没问题的。
ps:np目前python版只有前向infer的实现,fp32和fp16是没有问题的。目前单测的后向计算grad都是统一使用op_test.py中的数值定义法(本身存在精度不够和误差大的问题),这也是造成grad无法通过单测的原因。

@@ -284,10 +286,25 @@ def init_type(self):
self.dtype = np.float64


class TestWithFloat16(TestModulatedDeformableConvOp):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@unittest.skipIf(not core.is_compiled_with_cuda(),
                       "core is not compiled with CUDA")

FP16只有GPU支持,可以在单测前加如上装饰器。

Copy link
Contributor Author

@Rayman96 Rayman96 Oct 8, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@unittest.skipIf(not core.is_compiled_with_cuda(),
                       "core is not compiled with CUDA")

FP16只有GPU支持,可以在单测前加如上装饰器。

这里增加装饰器后还是会出错,根源应该还是在check_grad的算法(y_neg-y_pos) / delta /2的问题。我参考python/paddle/fluid/tests/unittests/test_margin_cross_entropy_op.py中的做法,将numeric_grad_delta扩大,这样可以扩大分子,但是会降低grad数值的精度。这样做的话Input和Filter的精度在单测中可以通过,但是Offset的波动太大,还是不行。
image

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我参考python/paddle/fluid/tests/unittests/test_margin_cross_entropy_op.py中的做法,将numeric_grad_delta扩大,这样可以扩大分子,但是会降低grad数值的精度。这样做的话Input和Filter的精度在单测中可以通过,但是Offset的波动太大,还是不行。

float16本身的有效位数只有3位,所以atol、rtol、max_relative_error设置成1e-3都是合理的。你看下单测中需要放大到多少?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我参考python/paddle/fluid/tests/unittests/test_margin_cross_entropy_op.py中的做法,将numeric_grad_delta扩大,这样可以扩大分子,但是会降低grad数值的精度。这样做的话Input和Filter的精度在单测中可以通过,但是Offset的波动太大,还是不行。

float16本身的有效位数只有3位,所以atol、rtol、max_relative_error设置成1e-3都是合理的。你看下单测中需要放大到多少?

我观察我实现的结果本身是可以满足1e-3的数值的,但是op_test文件始终产生不了这样精度的check_grad值,所以一直通不过

@Rayman96
Copy link
Contributor Author

Rayman96 commented Oct 8, 2022

image

benchmark里结果是这样的,paddle实现的代码应该是准确的,但是benchmark用来对照的好像是有误的。

image

CI流程中也都给出的是0.02441这个值

@Xreki 这里的截图

max_relative_error=2e-1,
no_grad_set=set(['Filter']),
check_eager=True)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

建议
可以按照(2)中的提示直接写一个和fp32结果比较的单测,不需要(1)。
如果时间来不及,那么就先按照(1)将这个PR的单测问题解决尽量赶在今天合入,之后再按照(2)补一个PR。

(1)这个PR可以暂时先不用check_grad去检查梯度。使用skip_check_grad_ci装饰器暂时跳过梯度检查。参考如下单测,写上原因:

@skip_check_grad_ci(
reason="The function 'check_grad' for large inputs is too slow.")
class TestConcatOp3(TestConcatOp):
def init_test_data(self):
self.x0 = np.random.random((1, 256, 170, 256)).astype(self.dtype)
self.x1 = np.random.random((1, 128, 170, 256)).astype(self.dtype)
self.x2 = np.random.random((1, 128, 170, 256)).astype(self.dtype)
self.axis = 1
def test_check_grad(self):
pass

(2)单测中与fp32的结果进行比较**。你可以尝试:

  • TestWithFloat16这个test继承unittest.TestCase
  • 单测中分别计算fp32和fp16的前、反向结果,并和fp32的计算结果进行精度的比较

可以参考test_layer_norm_op.py:

class TestFP16ScaleBiasLayerNorm(unittest.TestCase):
def check_main(self, x_np, weight_np, bias_np, dtype):
paddle.disable_static()
weight_np = weight_np.astype(dtype)
bias_np = bias_np.astype(dtype)
x = paddle.to_tensor(x_np)
weight = paddle.to_tensor(weight_np)
bias = paddle.to_tensor(bias_np)
x.stop_gradient = False
weight.stop_gradient = False
bias.stop_gradient = False
y = F.layer_norm(x, x.shape[1:], weight, bias)
x_g, w_g, b_g = paddle.grad(y, [x, weight, bias])
y_np = y.numpy().astype('float32')
x_g_np = x_g.numpy().astype('float32')
w_g_np = w_g.numpy().astype('float16')
b_g_np = b_g.numpy().astype('float32')
paddle.enable_static()
return y_np, x_g_np, w_g_np, b_g_np
def test_main(self):
if not paddle.is_compiled_with_cuda():
return
x_np = np.random.random([10, 20]).astype('float16')
weight_np = np.random.random([20]).astype('float16')
bias_np = np.random.random([20]).astype('float16')
y_np_1, x_g_np_1, w_g_np_1, b_g_np_1 = self.check_main(
x_np, weight_np, bias_np, 'float16')
y_np_2, x_g_np_2, w_g_np_2, b_g_np_2 = self.check_main(
x_np, weight_np, bias_np, 'float32')
def assert_equal(x, y):
np.testing.assert_array_equal(x, y)
assert_equal(y_np_1, y_np_2)
assert_equal(x_g_np_1, x_g_np_2)
assert_equal(w_g_np_1, w_g_np_2)
assert_equal(b_g_np_1, b_g_np_2)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

建议: 可以按照(2)中的提示直接写一个和fp32结果比较的单测,不需要(1)。 如果时间来不及,那么就先按照(1)将这个PR的单测问题解决尽量赶在今天合入,之后再按照(2)补一个PR。

(1)这个PR可以暂时先不用check_grad去检查梯度。使用skip_check_grad_ci装饰器暂时跳过梯度检查。参考如下单测,写上原因:

@skip_check_grad_ci(
reason="The function 'check_grad' for large inputs is too slow.")
class TestConcatOp3(TestConcatOp):
def init_test_data(self):
self.x0 = np.random.random((1, 256, 170, 256)).astype(self.dtype)
self.x1 = np.random.random((1, 128, 170, 256)).astype(self.dtype)
self.x2 = np.random.random((1, 128, 170, 256)).astype(self.dtype)
self.axis = 1
def test_check_grad(self):
pass

(2)单测中与fp32的结果进行比较**。你可以尝试:

  • TestWithFloat16这个test继承unittest.TestCase
  • 单测中分别计算fp32和fp16的前、反向结果,并和fp32的计算结果进行精度的比较

可以参考test_layer_norm_op.py:

class TestFP16ScaleBiasLayerNorm(unittest.TestCase):
def check_main(self, x_np, weight_np, bias_np, dtype):
paddle.disable_static()
weight_np = weight_np.astype(dtype)
bias_np = bias_np.astype(dtype)
x = paddle.to_tensor(x_np)
weight = paddle.to_tensor(weight_np)
bias = paddle.to_tensor(bias_np)
x.stop_gradient = False
weight.stop_gradient = False
bias.stop_gradient = False
y = F.layer_norm(x, x.shape[1:], weight, bias)
x_g, w_g, b_g = paddle.grad(y, [x, weight, bias])
y_np = y.numpy().astype('float32')
x_g_np = x_g.numpy().astype('float32')
w_g_np = w_g.numpy().astype('float16')
b_g_np = b_g.numpy().astype('float32')
paddle.enable_static()
return y_np, x_g_np, w_g_np, b_g_np
def test_main(self):
if not paddle.is_compiled_with_cuda():
return
x_np = np.random.random([10, 20]).astype('float16')
weight_np = np.random.random([20]).astype('float16')
bias_np = np.random.random([20]).astype('float16')
y_np_1, x_g_np_1, w_g_np_1, b_g_np_1 = self.check_main(
x_np, weight_np, bias_np, 'float16')
y_np_2, x_g_np_2, w_g_np_2, b_g_np_2 = self.check_main(
x_np, weight_np, bias_np, 'float32')
def assert_equal(x, y):
np.testing.assert_array_equal(x, y)
assert_equal(y_np_1, y_np_2)
assert_equal(x_g_np_1, x_g_np_2)
assert_equal(w_g_np_1, w_g_np_2)
assert_equal(b_g_np_1, b_g_np_2)

好的 我尽快按照(1)中提交看下CI结果。如果下午6点前(2)还来不及的话我后边再补一个

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

按照(2)实现的测试方法提交后CI都全部通过了,Output和grad都没问题。我再改一下CodeStyle重新触发流水线,晚上应该就可以全部跑好

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

按照(2)实现后测试 CI已全部通过

Copy link
Contributor

@zhangting2020 zhangting2020 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
// using MT = typename phi::dtype::MPTypeTrait<T>::Type;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这行注释可以另提一个PR删除

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这行注释可以另提一个PR删除

好的

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这行注释可以另提一个PR删除

提交了已通过CI,麻烦审核 #46827

@zhangting2020 zhangting2020 merged commit 5e0614a into PaddlePaddle:develop Oct 10, 2022
@Xreki
Copy link
Contributor

Xreki commented Oct 10, 2022

image

deformable_conv前向性能提升34%,符合黑客松算子优化验收标准。

因OP Benchmark系统中默认只有1个测试配置,建议可以用https://github.com/PaddlePaddle/benchmark/blob/master/api/tests_v2/model_configs/deformable_conv.json 中更多配置验证下性能提升效果。

@Rayman96
Copy link
Contributor Author

image

deformable_conv前向性能提升34%,符合黑客松算子优化验收标准。

因OP Benchmark系统中默认只有1个测试配置,建议可以用https://github.com/PaddlePaddle/benchmark/blob/master/api/tests_v2/model_configs/deformable_conv.json 中更多配置验证下性能提升效果。

好的 我后边测试完后再贴一下结果

zhangting2020 added a commit to zhangting2020/Paddle that referenced this pull request Oct 12, 2022
@zhangting2020 zhangting2020 mentioned this pull request Oct 12, 2022
@zhangting2020
Copy link
Contributor

目前此PR引起了ppyolov2模型运行失败,报错如下:
84fc2b8478d28629883a1f41d8d41719

可能与下面的修改有关,在fp16数据类型下,计算过程可以使用fp32,但是输入输出需要保持fp16精度,pr中修改了dx的数据类型。
image

@Rayman96
Copy link
Contributor Author

目前此PR引起了ppyolov2模型运行失败,报错如下: 84fc2b8478d28629883a1f41d8d41719

可能与下面的修改有关,在fp16数据类型下,计算过程可以使用fp32,但是输入输出需要保持fp16精度,pr中修改了dx的数据类型。 image

好的我修改一下,这里确实dx输出没有保持fp16。修改后我是重新开一个PR吗?

@Rayman96 Rayman96 mentioned this pull request Oct 13, 2022
@Rayman96
Copy link
Contributor Author

Rayman96 commented Oct 13, 2022

zhangting2020 pushed a commit that referenced this pull request Oct 13, 2022
* Revert "【Hackathon No.56&38】deformable_conv_v1 算子实现 float16 数据类型支持&前向运行加速 (#46111)"
@Rayman96
Copy link
Contributor Author

根据建议将前向加速的代码拆出来先提交了一个单独的PR,CI已全部通过,幸苦帮忙审核合入
#46997

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
contributor External developers
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants