Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modify the unsqueeze dimension of input data in conv1d NCL And NLC format #38425

Merged
merged 11 commits into from
Feb 10, 2022

Conversation

Zjq9409
Copy link
Contributor

@Zjq9409 Zjq9409 commented Dec 24, 2021

PR types

Performance optimization

PR changes

OPs

Describe

对于有些配置paddle_dev分支与竞品性能差距过大,对NCL和NLC格式输入数据进行优化,主要优化点有:

  1. 修改维度扩充方式

    • 对于NCL格式input从原始padde_dev对[-1]维度进行扩充,修改为对[-2]进行维度扩充,如下:
      image

    • 对于NLC格式input从原始paddle_dev对[-2]维度做扩充修改为[-3]维度做扩充,如下:
      image

    • 不论格式,weight对[-2]维度扩充,如:[80,80,5] ----> [80,80,1,5]

    • 不论格式,padding/stride/dilation对[0]维度做扩充;

  2. 当在GPU上编译时,并且满足groups==num_channels and num_filters% num_channels == 0,底层执行depthwise_conv2d方法

经过优化后,前向性能数据如下:

case 竞品 优化前 优化前相比竞品 优化后 优化后相比竞品 加速比
w:[64,1,1] x:[6,1,25500] format:NCL dilation:1 groups1 padding:0 stride:1 0.07823 0.16254 差于 (1.06x) 0.08077 打平 (2.42%) 2.01
w:[80,80,5] x:[6,80,89] format:NCL dilation:1 groups:1 padding:0 stride:1 0.02728 0.05356 差于 (96.12%) 0.02730 打平 (0.11%) 1.96
w:[128,64,3] x:[6,64,25500] format:NCL dilation:2 groups:1 padding:2 stride:1 0.65863 0.89356 差于 (34.58%) 0.65985 打平 (0.53%) 1.35
w:[128,64,3] x:[6,64,25500] format:NCL dilation:64 groups:1 padding:64 stride:1 0.94517 0.91143 打平 (3.73%) 0.91976 打平 (3.36%) 0.99
w:[256,256,1] x:[16,256,36] format:NCL dilation:1 groups:1 padding:0 stride:1 0.03781 0.03768 打平 (0.11%) 0.03767 打平 (0.03%) 1.00
w:[256,1,15] x:[16,256,40] format:NCL dilation:1 groups:256 padding:7 stride:1 0.00667 0.01218 差于 (81.79%) 0.00810 差于 (20.90%) 1.50
w:[512,256,1] x:[16,256,43] format:NCL dilation:1 groups:1 padding:0 stride:1 0.03847 0.03823 打平 (0.10%) 0.03805 打平 (0.13%) 1.00
w:[1024,64,3] x:[6,64,25500] format:NCL dilation:4 groups:1 padding:4 stride:1 5.30080 6.34554 差于 (19.71%) 5.23961 打平 (1.01%) 1.21
w:[1024,64,3] x:[6,64,25500] format:NCL dilation:8 groups:1 padding:8 stride:1 5.35025 6.38406 差于 (19.32%) 5.25906 打平 (0.97%) 1.21

反向性能数据如下:

case 竞品 优化前 优化前相比竞品 优化后 优化后相比竞品 加速比
w:[64,1,1] x:[6,1,25500] format:NCL dilation:1 groups1 padding:0 stride:1 0.25707 174.04358 差于(676.03x) 0.36675 差于 (42.67%) 474.56
w:[80,80,5] x:[6,80,89] format:NCL dilation:1 groups:1 padding:0 stride:1 0.04760 0.08536 差于 (90.37%) 0.05322 差于 (11.81%) 1.60
w:[128,64,3] x:[6,64,25500] format:NCL dilation:2 groups:1 padding:2 stride:1 1.27434 211.91023 差于(165.29x) 1.53514 差于 (20.47%) 138.04
w:[128,64,3] x:[6,64,25500] format:NCL dilation:64 groups:1 padding:64 stride:1 1.35295 2.69687 差于 (99.20%) 1.63474 差于 (20.83%) 1.65
w:[256,256,1] x:[16,256,36] format:NCL dilation:1 groups:1 padding:0 stride:1 0.05514 0.08157 差于 (49.07%) 0.07677 差于 (39.23%) 1.06
w:[256,1,15] x:[16,256,40] format:NCL dilation:1 groups:256 padding:7 stride:1 0.04213 0.03405 优于 (19.14%) 0.05021 差于 (19.18%) 0.68
w:[512,256,1] x:[16,256,43] format:NCL dilation:1 groups:1 padding:0 stride:1 0.14101 0.14736 打平 (4.72%) 0.14632 打平 (3.77%) 1.01
w:[1024,64,3] x:[6,64,25500] format:NCL dilation:4 groups:1 padding:4 stride:1 9.58404 11.75670 差于 (22.67%) 10.90966 差于 (15.33%) 1.08
w:[1024,64,3] x:[6,64,25500] format:NCL dilation:8 groups:1 padding:8 stride:1 9.45516 11.83995 差于 (25.22%) 10.91164 差于 (14.80%) 1.09

w:[256,1,15] x:[16,256,40] format:NCL dilation:1 groups:256 padding:7 stride:1 配置性能下降原因是,满足(2)这个条件,底层执行了depthwise_conv2d方法,而depthwise_conv2d方法,paddle相差竞品2.21x,后续优化该算子,性能能持平

注:以上数据在V100,32G, cudnn 7上测试

@paddle-bot-old
Copy link

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@Zjq9409 Zjq9409 changed the title optimize conv1d forward optimize conv1d op Jan 6, 2022
@Zjq9409 Zjq9409 changed the title optimize conv1d op Modify the unsqueeze dimension of input data in NCL format Jan 10, 2022
@Zjq9409 Zjq9409 changed the title Modify the unsqueeze dimension of input data in NCL format Modify the unsqueeze dimension of input data in conv1d NCL format Jan 10, 2022
@Zjq9409 Zjq9409 force-pushed the cudnn_conv_opt branch 3 times, most recently from e040896 to 3cc8958 Compare January 10, 2022 06:19
@Zjq9409 Zjq9409 changed the title Modify the unsqueeze dimension of input data in conv1d NCL format Modify the unsqueeze dimension of input data in conv1d NCL And NLC format Jan 13, 2022
python/paddle/nn/functional/conv.py Outdated Show resolved Hide resolved
python/paddle/nn/functional/conv.py Show resolved Hide resolved
@ZzSean
Copy link
Contributor

ZzSean commented Jan 17, 2022

建议把不同优化方法对应的case标记出来,或者把case的详细config贴上来

@Zjq9409
Copy link
Contributor Author

Zjq9409 commented Jan 17, 2022

建议把不同优化方法对应的case标记出来,或者把case的详细config贴上来

已经附上了每个case详细的config

Copy link
Contributor

@JamesLim-sy JamesLim-sy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

当在GPU上编译时,并且满足一定条件,底层执行depthwise_conv2d方法,如下conv18配置
满足什么条件,希望能表述清晰

python/paddle/nn/functional/conv.py Outdated Show resolved Hide resolved
@Zjq9409
Copy link
Contributor Author

Zjq9409 commented Jan 18, 2022

当在GPU上编译时,并且满足一定条件,底层执行depthwise_conv2d方法,如下conv18配置 满足什么条件,希望能表述清晰

具体条件已经加上

wangxinxin08
wangxinxin08 previously approved these changes Jan 20, 2022
Copy link
Contributor

@wangxinxin08 wangxinxin08 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

可以在设置num_filters=512的位置记一个TODO

python/paddle/nn/functional/conv.py Outdated Show resolved Hide resolved
ZzSean
ZzSean previously approved these changes Jan 26, 2022
Copy link
Contributor

@ZzSean ZzSean left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM,对于上述讨论的会影响某些case的问题,可以展示下对相关模型性能的影响情况,若下降影响程度不大,而提升效果明显的话,可以先合入,待后续自动配置优化完善后,该问题应该可以解决。

Copy link
Contributor

@wangxinxin08 wangxinxin08 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@ZzSean ZzSean left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@ZzSean ZzSean merged commit 224bc51 into PaddlePaddle:develop Feb 10, 2022
Shixiaowei02 added a commit that referenced this pull request Feb 16, 2022
* 【Pten】Adjust the Empyt dev_api (#39143)

* adjust the Empyt dev_api

* fix merge conflict

* fix sparse_utils_kernel

* Fix code conflict of empty dev_api (#39430)

* fix code conflict

* clear cache

* just try

* [PluggableDevice] custom kernel supports multi cpp_dtype registering (#39385)

* [PTen] Add standard kernel suffix set (#39404)

* add standard_suffix_set_and_remove_reshape_with_xshape

* revert reshape change

* polish reduce name

* [pten] update isnan registration (#39419)

* update isnan registration

* fix compile

* [bf16] add bf16 kernel: dropout & reshape & slice (#39395)

* add dropout

* add reshape

* add slice

* refien slice unittest

* refine slice unittest

* add cpu bf16 kernel

* [bf16] add bf16 kernel: squeeze & unsqueeze & stack (#39402)

* add squeeze unsqueeze stack

* add unittest

* add cpu kernel

* Modify the unsqueeze dimension of input data in conv1d NCL And NLC format (#38425)

* optimize conv1d forward

* add conv opt

* Optimize memory copy

* delete share data with

* set num_filters=512

* add nlc optimize

* Optimize num_filter=512 data on A100 and V100

* Fix the workspace_size size setting of filter

* 【Pten】Refactor C++ API code-gen (#39408)

* refactor C++ API code-gen

* fix windows problem of C++ API

* Refactored Python-C Attributes Parsing Functions (#39328)

* Add _get_parameter method to Lamb optimizer (#39416)

* add _get_parameter func to lamb

* remove duplicate code

* mkldnn layout issue fix (#39422)

* mkldnn conv fix

* definetion

* fix compile error on jetson (#39441)

* move Masked select to pten (#39193)

* move masked select cpu kernel

* add masked selected gpu kernel; test=develop

* fix bugs; test=develop

* bug fix; test=develop

* bug fix; test=develop

* add namespace to set mask array; test=develop

* fix bug; test=develop

* fix bugs; test=develop

* fix ddim bug; test=develop

* fix npu op bug; test=develop

* fix xpu dependecy bug; test=develop

* move kernel args to sig.cc; test=develop

* 【PaddlePaddle Hackathon】31. Add Java frontend for Paddle Inference  (#37162)

* fix check error of ResetHolder (#39439)

* Added python-c code generation for final state Eager Dygraph (#39233)

* Removed debug info

* Added automatic code generation for final state Eager Dygraph

* Modified backward yaml

* Added EagerUtils helper functions for final state CodeGen

* Adjusted CMakeFiles to support compilation for final state auto generated codes

* Added python-c code generation for final state Eager Dygraph

* Fixed minor issue

* Fixed yaml.load() method failure

* Fixed minor issues

* Refactored Python-C Attributes Parsing Functions

* Fixed minor issue with Python-C AddFunctions

* Fixed issues from merge

* Fixed merge issues

* change dtype of pooling mask to 'int32' for Paddle2ONNX (#39314)

* change dtype of pooling mask to 'int32' for Paddle2ONNX

* empty commit to rerun ci

* fix format

* share MemOptVarInfos of external variables into cinn_launch subgraph (#39209)

* add a graph pass to share MemOptVarInfos of external variables into subgraph

* update pass name

* fix compile failed

* add share_mem_opt_info_to_subgraph_pass test

* share_mem_opt_info_to_subgraph_pass_test pass

* modify some codes for better style and more robust

* update cmake

* [NPU] add reduce_min (#39019)

[NPU] add reduce_min

* [MLU] add mlu kernel for accuracy op (#39337)

* [MLU] add mlu kernel for accuracy op

* fix license format

* fix error message

* [Dy2St]Handle `a, b = paddle.shape(x)` in Static Analysis (#39245)

* refine Assign

* add UT

* 【Pten】Auto-Generate InterMeta register (#39436)

* fix code conflict

* generate inter_meta register

* clear cache

* just try

* add sign c++ api

* polish some code

* Support different dtypes of inputs for elementwise ops (#38859)

* improve backward performance

* support different dtypes for elementwise ops

* Add profiler node tree implementation (#39316)

* add event node implementation

* modify profiler.stop interface

* fix according to review

* fix file mode

* modify class method name in event_node.cc

* modify LLONG_MAX to ULLONG_MAX

* fix ci error

* fix ci error

* add print pten kernel tool (#39371)

* test=document_fix;add print pten kernel tool

* test=document_fix

* test=document_fix

* test=document_fix

* test=document_fix

* add print_pten_kernels tool

* add print_pten_kernels tool

* fix windows complie

* notest,test=rocm_ci

* add merge tool

* add comments

* [new-exec] set type of op-kernel op by place (#39458)

* Add log for executor (#39459)

* add align for WorkQueue

* add spinlock

* merge develop

* merge

* Add EventsWaiter

* Revert "Add EventsWaiter"

This reverts commit e206173.

* add log for Executor

Co-authored-by: liutiexing <liutiexing@google.com>

* [Paddle Inference] support ernie quant model with interleaved (#39424)

* support ernie quant model with interleaved

* support ernie quant model with interleaved

* support ernie quant model with interleaved

* support ernie quant model with interleaved

* support ernie quant model with interleaved

* support ernie quant model with interleaved

* support ernie quant model with interleaved

* 统一 ps 开发 - python (#39431)

* delete gloo connect retry

* the_one_ps dirs reconstruct

* .

* .

* create the_one_ps dirs

* create the_one_ps dirs

* create the_one_ps dirs

* create the_one_ps dirs

* create the_one_ps dirs

* create the_one_ps dirs

* the one ps dirs modify

* the one ps dirs modify

* the one ps dirs modify

* the one ps dirs modify

* refactor ps optimize

* refactor ps optimize

* refactor ps optimize

* .

* .

* .

* .

* .

* .

* refactor theoneps

* the_one_ps

* add ps pass unittest

* add ps pass unittest

* ps unitest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* add cpu_async_ps_mode test

* add cpu_async_ps_mode test

* add cpu_async_ps_mode test

* ps unittest ready

* ps unittest ready

* solve dist_pass init conflict

* solve import CommContext error

* unittest ok

* implement AllocateFrom

* solve setup.py.in conflict

* solve conflict

* solve conflict

* solve conflict

* .

* .

* cpu-async-ps minimize test ok & gpu minimize test ok

Co-authored-by: zkh2016 <zhangkaihuo@baidu.com>

* [PTen] Move grad GetExpectedPtenKernelArgs into pten (#39418)

* move grad get expected pten kernel args

* fix reduce sum error

* fix element_sub_grad failed

* revert kernel judge change

* fix compilation warning on mac (#39438)

* get build time (#39368)

* fix prelu trt convert (#39389)

* Optimize bilinear interpolation foward (#39243)

* bilinear_fw init

* optimize code

* pre-compute linear_interp input index

* Optimize performance of softmax_bwd when axis!=-1 (#38609)

* Optimize performance of softmax_bwd when axis!=-1

* fix

* fix

* fix

* fix

* [PTen] Remove pten core's dependency on fluid xxx_info.h (#39401)

* ermove xxx_info include

* fix namespace error

* resolve conflict

* skip xpu context in registry

* fix macro error

* resolve conflict

* resolve conflict

* revert xpu convert

* remove trans to fluid place

* remove useless headers

* [Pten] move operators/math/math_function_* to pten/kernels/func (#39300)

* move operators/math/math_function_* to pten/kernels/func
* namespace from `paddle::operators::math` to `pten::funcs`

* [MLU] add pool2d and pool2d_grad mlu kernel (#39453)

* [MLU]support c_gen_cncl_id_op run on MLU device (#39336)

Co-authored-by: zhangna <zhangna@cambricon.com>

* [bf16] add bf16 kernel: transpose & unbind (#39457)

* add transpose unbind

* add unittest

* refine transpose unittest

* uniform_random op for mlu (#39450)

* [MLU] add pool2d pytest (#39454)

* Added shape (U)INT8/BF16/FP32 oneDNN kernel (#36033)

* added shape oneDNN kernel

* removed unnecessary import from test

* added skipping tests for GPU

* refactoring

* refactored shape kernel

* added tests in new framework

* removed one line

* minor change

* added newline at EOF

* added formatting

* added attributes as extra

* move memcpy.h into cc file (#39469)

* Add TensorRT inspector into Paddle-TRT (#38362)

* Fix add profiler node tree implementation cmake error (#39474)

* add event node implementation

* modify profiler.stop interface

* fix according to review

* fix file mode

* modify class method name in event_node.cc

* modify LLONG_MAX to ULLONG_MAX

* fix ci error

* fix ci error

* fix dependency error

* unify naming style (#39481)

* [Pten] Generate Wrapped InferMeta by Yaml (#39482)

* generate wrapped_infer_meta

* add test for wrapped_infer_meta

* Update test_meta_fn_utils.cc

* change the dir of generated file

Co-authored-by: Chen Weihang <chenweihang@baidu.com>
Co-authored-by: Chen Weihang <chenwhpro@163.com>

* Adjusted python-level trace_op to accomodate final state Eager Dygraph (#39319)

* Removed debug info

* Added automatic code generation for final state Eager Dygraph

* Modified backward yaml

* Added EagerUtils helper functions for final state CodeGen

* Adjusted CMakeFiles to support compilation for final state auto generated codes

* Added python-c code generation for final state Eager Dygraph

* Fixed minor issue

* Fixed yaml.load() method failure

* Fixed minor issues

* Refactored Python-C Attributes Parsing Functions

* Fixed minor issue with Python-C AddFunctions

* Adjusted python-level trace_op to accomodate final state Eager Dygraph

* Added Logs for final state Eager Dygraph

* Fixed merge issues

* Fixed minor issue

* Fixed get_tensor method for EagerTensor (#39414)

* Enabled Eager OpTest #1

* Enabled Eager OpTest #1

* Fixed get_tensor method for EagerTensor

* [Approver Update] update check approver of qili93, test=document_fix (#39483)

* [MLU] add mlu kernel for c_broadcast op (#39470)

* update xpu test build script and fix get_test_cover_info, *test=kunlun (#39235)

* fix gather_nd, *test=kunlun (#39283)

* [pten] add split kernel (#39060)

* add split kernel

* add split kernel signature

* fix split bug

* modify MakePtenScalarArrayFromVarList

* modify MakePtenScalarArrayFromVarList

* fix split windows register error

* add test case for split kernel

* replace raw split kernel with pten kernel

* fix makeScalar/ScalarArray bug

* remove debug log

* remove int64_t type in buildPtcontext

* update by code review

* fix split dev test failed

* change DenseTensorMeta to MetaTensor

* change split api code from auto gen to manual

* split cuda kernel support bfloat16 type

* fix conflict

* rm raw split kernel

* merge develop branch

* change to pten::errors

* new may of test cases, *test=kunlun (#39444)

* new may of test cases, *test=kunlun

* new may of test cases, *test=kunlun

* new may of test cases, *test=kunlun

* [PTen] Add HasAttr for ArgumentMappingContext (#39464)

* add has_attr for arg map context

* skip useless attr now

* skip attr if not exists

* fix typo

* [ROCm] fix missing dcu kernel in operator.cmake, test=develop (#39480)

Co-authored-by: zyfncg <zhangyunfei07@baidu.com>
Co-authored-by: Aganlengzi <aganlengzi@gmail.com>
Co-authored-by: Chen Weihang <chenweihang@baidu.com>
Co-authored-by: Leo Chen <chenqiuliang@baidu.com>
Co-authored-by: zhangbo9674 <82555433+zhangbo9674@users.noreply.github.com>
Co-authored-by: crystal <62974595+Zjq9409@users.noreply.github.com>
Co-authored-by: Zhanlue Yang <jim19930609@gmail.com>
Co-authored-by: sneaxiy <32832641+sneaxiy@users.noreply.github.com>
Co-authored-by: wenbin <wang3323032@qq.com>
Co-authored-by: Wilber <jiweibo@baidu.com>
Co-authored-by: hong <43953930+phlrain@users.noreply.github.com>
Co-authored-by: chenyanlann <62465397+chenyanlann@users.noreply.github.com>
Co-authored-by: Wei Shengyu <weisy11@163.com>
Co-authored-by: TeFeng Chen <ctfeng66@163.com>
Co-authored-by: furnace <34057289+windstamp@users.noreply.github.com>
Co-authored-by: fwenguang <95677191+fwenguang@users.noreply.github.com>
Co-authored-by: 0x45f <23097963+0x45f@users.noreply.github.com>
Co-authored-by: Zhang Ting <zhangting_2017@163.com>
Co-authored-by: chenjian <chenjian26@baidu.com>
Co-authored-by: Shang Zhizhou <shangzhizhou@baidu.com>
Co-authored-by: liutiexing <74819124+liutiexing@users.noreply.github.com>
Co-authored-by: liutiexing <liutiexing@google.com>
Co-authored-by: Wangzheee <634486483@qq.com>
Co-authored-by: ziyoujiyi <73728031+ziyoujiyi@users.noreply.github.com>
Co-authored-by: zkh2016 <zhangkaihuo@baidu.com>
Co-authored-by: zhangchunle <clzhang_cauc@163.com>
Co-authored-by: JingZhuangzhuang <75348594+JZZ-NOTE@users.noreply.github.com>
Co-authored-by: Lijunhui <1578034415@qq.com>
Co-authored-by: Zhang Zheng <32410583+ZzSean@users.noreply.github.com>
Co-authored-by: Feiyu Chan <chenfeiyu@baidu.com>
Co-authored-by: zn <96479180+kangna-qi@users.noreply.github.com>
Co-authored-by: zhangna <zhangna@cambricon.com>
Co-authored-by: joeqiao12 <45232181+joeqiao12@users.noreply.github.com>
Co-authored-by: jakpiase <jakpia21@gmail.com>
Co-authored-by: Leo Chen <39020268+leo0519@users.noreply.github.com>
Co-authored-by: Chen Weihang <chenwhpro@163.com>
Co-authored-by: Qi Li <qili93@qq.com>
Co-authored-by: maxhuiy <1508399706@qq.com>
Co-authored-by: TTerror <tangzhiyi11@users.noreply.github.com>
Co-authored-by: chentianyu03 <chentianyu03@baidu.com>
Co-authored-by: helen88 <z8hanghuan@126.com>
winter-wang pushed a commit to winter-wang/Paddle that referenced this pull request Feb 16, 2022
* 【Pten】Adjust the Empyt dev_api (PaddlePaddle#39143)

* adjust the Empyt dev_api

* fix merge conflict

* fix sparse_utils_kernel

* Fix code conflict of empty dev_api (PaddlePaddle#39430)

* fix code conflict

* clear cache

* just try

* [PluggableDevice] custom kernel supports multi cpp_dtype registering (PaddlePaddle#39385)

* [PTen] Add standard kernel suffix set (PaddlePaddle#39404)

* add standard_suffix_set_and_remove_reshape_with_xshape

* revert reshape change

* polish reduce name

* [pten] update isnan registration (PaddlePaddle#39419)

* update isnan registration

* fix compile

* [bf16] add bf16 kernel: dropout & reshape & slice (PaddlePaddle#39395)

* add dropout

* add reshape

* add slice

* refien slice unittest

* refine slice unittest

* add cpu bf16 kernel

* [bf16] add bf16 kernel: squeeze & unsqueeze & stack (PaddlePaddle#39402)

* add squeeze unsqueeze stack

* add unittest

* add cpu kernel

* Modify the unsqueeze dimension of input data in conv1d NCL And NLC format (PaddlePaddle#38425)

* optimize conv1d forward

* add conv opt

* Optimize memory copy

* delete share data with

* set num_filters=512

* add nlc optimize

* Optimize num_filter=512 data on A100 and V100

* Fix the workspace_size size setting of filter

* 【Pten】Refactor C++ API code-gen (PaddlePaddle#39408)

* refactor C++ API code-gen

* fix windows problem of C++ API

* Refactored Python-C Attributes Parsing Functions (PaddlePaddle#39328)

* Add _get_parameter method to Lamb optimizer (PaddlePaddle#39416)

* add _get_parameter func to lamb

* remove duplicate code

* mkldnn layout issue fix (PaddlePaddle#39422)

* mkldnn conv fix

* definetion

* fix compile error on jetson (PaddlePaddle#39441)

* move Masked select to pten (PaddlePaddle#39193)

* move masked select cpu kernel

* add masked selected gpu kernel; test=develop

* fix bugs; test=develop

* bug fix; test=develop

* bug fix; test=develop

* add namespace to set mask array; test=develop

* fix bug; test=develop

* fix bugs; test=develop

* fix ddim bug; test=develop

* fix npu op bug; test=develop

* fix xpu dependecy bug; test=develop

* move kernel args to sig.cc; test=develop

* 【PaddlePaddle Hackathon】31. Add Java frontend for Paddle Inference  (PaddlePaddle#37162)

* fix check error of ResetHolder (PaddlePaddle#39439)

* Added python-c code generation for final state Eager Dygraph (PaddlePaddle#39233)

* Removed debug info

* Added automatic code generation for final state Eager Dygraph

* Modified backward yaml

* Added EagerUtils helper functions for final state CodeGen

* Adjusted CMakeFiles to support compilation for final state auto generated codes

* Added python-c code generation for final state Eager Dygraph

* Fixed minor issue

* Fixed yaml.load() method failure

* Fixed minor issues

* Refactored Python-C Attributes Parsing Functions

* Fixed minor issue with Python-C AddFunctions

* Fixed issues from merge

* Fixed merge issues

* change dtype of pooling mask to 'int32' for Paddle2ONNX (PaddlePaddle#39314)

* change dtype of pooling mask to 'int32' for Paddle2ONNX

* empty commit to rerun ci

* fix format

* share MemOptVarInfos of external variables into cinn_launch subgraph (PaddlePaddle#39209)

* add a graph pass to share MemOptVarInfos of external variables into subgraph

* update pass name

* fix compile failed

* add share_mem_opt_info_to_subgraph_pass test

* share_mem_opt_info_to_subgraph_pass_test pass

* modify some codes for better style and more robust

* update cmake

* [NPU] add reduce_min (PaddlePaddle#39019)

[NPU] add reduce_min

* [MLU] add mlu kernel for accuracy op (PaddlePaddle#39337)

* [MLU] add mlu kernel for accuracy op

* fix license format

* fix error message

* [Dy2St]Handle `a, b = paddle.shape(x)` in Static Analysis (PaddlePaddle#39245)

* refine Assign

* add UT

* 【Pten】Auto-Generate InterMeta register (PaddlePaddle#39436)

* fix code conflict

* generate inter_meta register

* clear cache

* just try

* add sign c++ api

* polish some code

* Support different dtypes of inputs for elementwise ops (PaddlePaddle#38859)

* improve backward performance

* support different dtypes for elementwise ops

* Add profiler node tree implementation (PaddlePaddle#39316)

* add event node implementation

* modify profiler.stop interface

* fix according to review

* fix file mode

* modify class method name in event_node.cc

* modify LLONG_MAX to ULLONG_MAX

* fix ci error

* fix ci error

* add print pten kernel tool (PaddlePaddle#39371)

* test=document_fix;add print pten kernel tool

* test=document_fix

* test=document_fix

* test=document_fix

* test=document_fix

* add print_pten_kernels tool

* add print_pten_kernels tool

* fix windows complie

* notest,test=rocm_ci

* add merge tool

* add comments

* [new-exec] set type of op-kernel op by place (PaddlePaddle#39458)

* Add log for executor (PaddlePaddle#39459)

* add align for WorkQueue

* add spinlock

* merge develop

* merge

* Add EventsWaiter

* Revert "Add EventsWaiter"

This reverts commit e206173.

* add log for Executor

Co-authored-by: liutiexing <liutiexing@google.com>

* [Paddle Inference] support ernie quant model with interleaved (PaddlePaddle#39424)

* support ernie quant model with interleaved

* support ernie quant model with interleaved

* support ernie quant model with interleaved

* support ernie quant model with interleaved

* support ernie quant model with interleaved

* support ernie quant model with interleaved

* support ernie quant model with interleaved

* 统一 ps 开发 - python (PaddlePaddle#39431)

* delete gloo connect retry

* the_one_ps dirs reconstruct

* .

* .

* create the_one_ps dirs

* create the_one_ps dirs

* create the_one_ps dirs

* create the_one_ps dirs

* create the_one_ps dirs

* create the_one_ps dirs

* the one ps dirs modify

* the one ps dirs modify

* the one ps dirs modify

* the one ps dirs modify

* refactor ps optimize

* refactor ps optimize

* refactor ps optimize

* .

* .

* .

* .

* .

* .

* refactor theoneps

* the_one_ps

* add ps pass unittest

* add ps pass unittest

* ps unitest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* add cpu_async_ps_mode test

* add cpu_async_ps_mode test

* add cpu_async_ps_mode test

* ps unittest ready

* ps unittest ready

* solve dist_pass init conflict

* solve import CommContext error

* unittest ok

* implement AllocateFrom

* solve setup.py.in conflict

* solve conflict

* solve conflict

* solve conflict

* .

* .

* cpu-async-ps minimize test ok & gpu minimize test ok

Co-authored-by: zkh2016 <zhangkaihuo@baidu.com>

* [PTen] Move grad GetExpectedPtenKernelArgs into pten (PaddlePaddle#39418)

* move grad get expected pten kernel args

* fix reduce sum error

* fix element_sub_grad failed

* revert kernel judge change

* fix compilation warning on mac (PaddlePaddle#39438)

* get build time (PaddlePaddle#39368)

* fix prelu trt convert (PaddlePaddle#39389)

* Optimize bilinear interpolation foward (PaddlePaddle#39243)

* bilinear_fw init

* optimize code

* pre-compute linear_interp input index

* Optimize performance of softmax_bwd when axis!=-1 (PaddlePaddle#38609)

* Optimize performance of softmax_bwd when axis!=-1

* fix

* fix

* fix

* fix

* [PTen] Remove pten core's dependency on fluid xxx_info.h (PaddlePaddle#39401)

* ermove xxx_info include

* fix namespace error

* resolve conflict

* skip xpu context in registry

* fix macro error

* resolve conflict

* resolve conflict

* revert xpu convert

* remove trans to fluid place

* remove useless headers

* [Pten] move operators/math/math_function_* to pten/kernels/func (PaddlePaddle#39300)

* move operators/math/math_function_* to pten/kernels/func
* namespace from `paddle::operators::math` to `pten::funcs`

* [MLU] add pool2d and pool2d_grad mlu kernel (PaddlePaddle#39453)

* [MLU]support c_gen_cncl_id_op run on MLU device (PaddlePaddle#39336)

Co-authored-by: zhangna <zhangna@cambricon.com>

* [bf16] add bf16 kernel: transpose & unbind (PaddlePaddle#39457)

* add transpose unbind

* add unittest

* refine transpose unittest

* uniform_random op for mlu (PaddlePaddle#39450)

* [MLU] add pool2d pytest (PaddlePaddle#39454)

* Added shape (U)INT8/BF16/FP32 oneDNN kernel (PaddlePaddle#36033)

* added shape oneDNN kernel

* removed unnecessary import from test

* added skipping tests for GPU

* refactoring

* refactored shape kernel

* added tests in new framework

* removed one line

* minor change

* added newline at EOF

* added formatting

* added attributes as extra

* move memcpy.h into cc file (PaddlePaddle#39469)

* Add TensorRT inspector into Paddle-TRT (PaddlePaddle#38362)

* Fix add profiler node tree implementation cmake error (PaddlePaddle#39474)

* add event node implementation

* modify profiler.stop interface

* fix according to review

* fix file mode

* modify class method name in event_node.cc

* modify LLONG_MAX to ULLONG_MAX

* fix ci error

* fix ci error

* fix dependency error

* unify naming style (PaddlePaddle#39481)

* [Pten] Generate Wrapped InferMeta by Yaml (PaddlePaddle#39482)

* generate wrapped_infer_meta

* add test for wrapped_infer_meta

* Update test_meta_fn_utils.cc

* change the dir of generated file

Co-authored-by: Chen Weihang <chenweihang@baidu.com>
Co-authored-by: Chen Weihang <chenwhpro@163.com>

* Adjusted python-level trace_op to accomodate final state Eager Dygraph (PaddlePaddle#39319)

* Removed debug info

* Added automatic code generation for final state Eager Dygraph

* Modified backward yaml

* Added EagerUtils helper functions for final state CodeGen

* Adjusted CMakeFiles to support compilation for final state auto generated codes

* Added python-c code generation for final state Eager Dygraph

* Fixed minor issue

* Fixed yaml.load() method failure

* Fixed minor issues

* Refactored Python-C Attributes Parsing Functions

* Fixed minor issue with Python-C AddFunctions

* Adjusted python-level trace_op to accomodate final state Eager Dygraph

* Added Logs for final state Eager Dygraph

* Fixed merge issues

* Fixed minor issue

* Fixed get_tensor method for EagerTensor (PaddlePaddle#39414)

* Enabled Eager OpTest PaddlePaddle#1

* Enabled Eager OpTest PaddlePaddle#1

* Fixed get_tensor method for EagerTensor

* [Approver Update] update check approver of qili93, test=document_fix (PaddlePaddle#39483)

* [MLU] add mlu kernel for c_broadcast op (PaddlePaddle#39470)

* update xpu test build script and fix get_test_cover_info, *test=kunlun (PaddlePaddle#39235)

* fix gather_nd, *test=kunlun (PaddlePaddle#39283)

* [pten] add split kernel (PaddlePaddle#39060)

* add split kernel

* add split kernel signature

* fix split bug

* modify MakePtenScalarArrayFromVarList

* modify MakePtenScalarArrayFromVarList

* fix split windows register error

* add test case for split kernel

* replace raw split kernel with pten kernel

* fix makeScalar/ScalarArray bug

* remove debug log

* remove int64_t type in buildPtcontext

* update by code review

* fix split dev test failed

* change DenseTensorMeta to MetaTensor

* change split api code from auto gen to manual

* split cuda kernel support bfloat16 type

* fix conflict

* rm raw split kernel

* merge develop branch

* change to pten::errors

* new may of test cases, *test=kunlun (PaddlePaddle#39444)

* new may of test cases, *test=kunlun

* new may of test cases, *test=kunlun

* new may of test cases, *test=kunlun

* [PTen] Add HasAttr for ArgumentMappingContext (PaddlePaddle#39464)

* add has_attr for arg map context

* skip useless attr now

* skip attr if not exists

* fix typo

* [ROCm] fix missing dcu kernel in operator.cmake, test=develop (PaddlePaddle#39480)

Co-authored-by: zyfncg <zhangyunfei07@baidu.com>
Co-authored-by: Aganlengzi <aganlengzi@gmail.com>
Co-authored-by: Chen Weihang <chenweihang@baidu.com>
Co-authored-by: Leo Chen <chenqiuliang@baidu.com>
Co-authored-by: zhangbo9674 <82555433+zhangbo9674@users.noreply.github.com>
Co-authored-by: crystal <62974595+Zjq9409@users.noreply.github.com>
Co-authored-by: Zhanlue Yang <jim19930609@gmail.com>
Co-authored-by: sneaxiy <32832641+sneaxiy@users.noreply.github.com>
Co-authored-by: wenbin <wang3323032@qq.com>
Co-authored-by: Wilber <jiweibo@baidu.com>
Co-authored-by: hong <43953930+phlrain@users.noreply.github.com>
Co-authored-by: chenyanlann <62465397+chenyanlann@users.noreply.github.com>
Co-authored-by: Wei Shengyu <weisy11@163.com>
Co-authored-by: TeFeng Chen <ctfeng66@163.com>
Co-authored-by: furnace <34057289+windstamp@users.noreply.github.com>
Co-authored-by: fwenguang <95677191+fwenguang@users.noreply.github.com>
Co-authored-by: 0x45f <23097963+0x45f@users.noreply.github.com>
Co-authored-by: Zhang Ting <zhangting_2017@163.com>
Co-authored-by: chenjian <chenjian26@baidu.com>
Co-authored-by: Shang Zhizhou <shangzhizhou@baidu.com>
Co-authored-by: liutiexing <74819124+liutiexing@users.noreply.github.com>
Co-authored-by: liutiexing <liutiexing@google.com>
Co-authored-by: Wangzheee <634486483@qq.com>
Co-authored-by: ziyoujiyi <73728031+ziyoujiyi@users.noreply.github.com>
Co-authored-by: zkh2016 <zhangkaihuo@baidu.com>
Co-authored-by: zhangchunle <clzhang_cauc@163.com>
Co-authored-by: JingZhuangzhuang <75348594+JZZ-NOTE@users.noreply.github.com>
Co-authored-by: Lijunhui <1578034415@qq.com>
Co-authored-by: Zhang Zheng <32410583+ZzSean@users.noreply.github.com>
Co-authored-by: Feiyu Chan <chenfeiyu@baidu.com>
Co-authored-by: zn <96479180+kangna-qi@users.noreply.github.com>
Co-authored-by: zhangna <zhangna@cambricon.com>
Co-authored-by: joeqiao12 <45232181+joeqiao12@users.noreply.github.com>
Co-authored-by: jakpiase <jakpia21@gmail.com>
Co-authored-by: Leo Chen <39020268+leo0519@users.noreply.github.com>
Co-authored-by: Chen Weihang <chenwhpro@163.com>
Co-authored-by: Qi Li <qili93@qq.com>
Co-authored-by: maxhuiy <1508399706@qq.com>
Co-authored-by: TTerror <tangzhiyi11@users.noreply.github.com>
Co-authored-by: chentianyu03 <chentianyu03@baidu.com>
Co-authored-by: helen88 <z8hanghuan@126.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants