Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PRIM][PIR]Migrate prim rules #57554

Merged
merged 53 commits into from
Sep 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
a815300
fix bugs of generating Op::Build when Op has optional tensor
lxd-cumt Sep 14, 2023
7a838a5
add default constructor for IrMetaTensor
lxd-cumt Sep 14, 2023
812e1e0
fix bugs
lxd-cumt Sep 15, 2023
466033d
polish guard
YuanRisheng Sep 18, 2023
bad4deb
pir support prim gelu and rsqrt
cyber-pioneer Sep 18, 2023
75b6197
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
lxd-cumt Sep 18, 2023
6d6bacd
support prim bwd ops
cyber-pioneer Sep 18, 2023
be43761
migrate vjp rules of cast,add,multiply,elementwise_pow
lxd-cumt Sep 18, 2023
38e1fba
add cast as primitive op
lxd-cumt Sep 18, 2023
b09d8c9
fix bugs in elementwise_pow_grad
lxd-cumt Sep 18, 2023
0b4ac51
add test for cast_grad
lxd-cumt Sep 18, 2023
61815b6
add test for elementwise_add_grad
lxd-cumt Sep 18, 2023
692ef47
add test for elementwise_mul_grad
lxd-cumt Sep 18, 2023
283c592
add test for elementwise_pow_grad
lxd-cumt Sep 18, 2023
c287053
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
lxd-cumt Sep 18, 2023
2dbd400
fix bugs
lxd-cumt Sep 18, 2023
c239a92
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
lxd-cumt Sep 19, 2023
a522d5e
fix bugs
lxd-cumt Sep 19, 2023
3d2172d
support pir prim backward ops
cyber-pioneer Sep 19, 2023
79a57af
Merge commit 'refs/pull/57474/head' of https://github.com/PaddlePaddl…
Charles-hit Sep 19, 2023
ef9a585
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
lxd-cumt Sep 19, 2023
13cea76
Merge commit 'refs/pull/57456/head' of https://github.com/PaddlePaddl…
lxd-cumt Sep 19, 2023
98549c8
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
lxd-cumt Sep 19, 2023
c0e8cae
refien
zhangbo9674 Sep 19, 2023
7c0ea2d
fix bug
zhangbo9674 Sep 20, 2023
d592830
Merge commit 'refs/pull/57453/head' of https://github.com/PaddlePaddl…
Charles-hit Sep 20, 2023
fe948ce
migrate layer_norm custom vjp rules to pir
lxd-cumt Sep 20, 2023
64a1448
fix bugs in ir_backward
lxd-cumt Sep 20, 2023
1308f41
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
lxd-cumt Sep 20, 2023
2da91b4
fix backward , scope, and concat_grad prim
Charles-hit Sep 20, 2023
3524882
Merge commit 'refs/pull/57474/head' of https://github.com/PaddlePaddl…
lxd-cumt Sep 20, 2023
dc148fc
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Charles-hit Sep 20, 2023
dd8e7e4
add layer_norm fwd decompose logic
lxd-cumt Sep 20, 2023
82c91dd
Merge commit 'refs/pull/57503/head' of https://github.com/PaddlePaddl…
Charles-hit Sep 20, 2023
2b69c73
fix pow
Charles-hit Sep 20, 2023
36e3e7b
Merge commit 'refs/pull/57526/head' of https://github.com/PaddlePaddl…
lxd-cumt Sep 20, 2023
b717e56
change _use_new_ir_api to in_pir_mode
lxd-cumt Sep 20, 2023
df93664
add _static_guard
lxd-cumt Sep 20, 2023
96f6088
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
lxd-cumt Sep 20, 2023
89978d9
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
lxd-cumt Sep 20, 2023
d0ec100
fix
zhangbo9674 Sep 20, 2023
4013187
fix executor cuda700 error caused by full and full_like
Charles-hit Sep 20, 2023
0785012
Merge commit 'refs/pull/57526/head' of https://github.com/PaddlePaddl…
lxd-cumt Sep 20, 2023
7961a0b
refine
zhangbo9674 Sep 20, 2023
1f18795
Merge commit 'refs/pull/57503/head' of https://github.com/PaddlePaddl…
Charles-hit Sep 20, 2023
33127c1
Merge commit 'refs/pull/57526/head' of https://github.com/PaddlePaddl…
lxd-cumt Sep 20, 2023
ccde153
add vjp rules
lxd-cumt Sep 20, 2023
ca46f4c
fix bugs
lxd-cumt Sep 20, 2023
4a27de0
add scope
lxd-cumt Sep 20, 2023
466b92b
add test
lxd-cumt Sep 20, 2023
2ec15d8
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
lxd-cumt Sep 20, 2023
49587bc
add add op prim rules
lxd-cumt Sep 21, 2023
d2e976a
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
lxd-cumt Sep 21, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions paddle/fluid/primitive/codegen/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,20 @@
]


PRIM_VJP = ['divide_grad', 'sum_grad'] # vjp list of primitive op
CUSTOM_VJP = ['gelu_grad'] # custom vjp list of composite op
PRIM_VJP = [
'divide_grad',
'sum_grad',
'cast_grad',
'add_grad',
'multiply_grad',
'elementwise_pow_grad',
'reshape_grad',
'split_grad',
'tanh_grad',
'transpose_grad',
'concat_grad',
] # vjp list of primitive op
CUSTOM_VJP = ['gelu_grad', 'layer_norm_grad'] # custom vjp list of composite op
VJP_COMPS = PRIM_VJP + CUSTOM_VJP

BACKENDS = [
Expand Down Expand Up @@ -149,6 +161,8 @@
'embedding_grad',
'sqrt',
'uniform',
'split',
'transpose',
]


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ paddle::Tensor* {{api.outputs[i].name}} = !stop_gradients[{{i}}][0] ? &vjp_res[{
{% else %}
std::vector<paddle::Tensor*> {{api.outputs[i].name}}(stop_gradients[{{i}}].size(), nullptr);
for (size_t i=0; i< stop_gradients[{{i}}].size(); i++ ) {
{{api.outputs[i].name}} = !stop_gradients[{{i}}][i] ? &vjp_res[{{i}}][i] : nullptr;
{{api.outputs[i].name}}[i] = !stop_gradients[{{i}}][i] ? &vjp_res[{{i}}][i] : nullptr;
}
{% endif %}
{% endfor %}
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/primitive/primitive.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,4 @@
- erf
- tanh
- full
- cast
Loading