Skip to content

Commit

Permalink
[PRIM][PIR]Migrate prim rules (PaddlePaddle#57554)
Browse files Browse the repository at this point in the history
* fix bugs of generating Op::Build when Op has optional tensor

* add default constructor for IrMetaTensor

* fix bugs

* polish guard

* pir support prim gelu and rsqrt

* support prim bwd ops

* migrate vjp rules of cast,add,multiply,elementwise_pow

* add cast as primitive op

* fix bugs in elementwise_pow_grad

* add test for cast_grad

* add test for elementwise_add_grad

* add test for elementwise_mul_grad

* add test for elementwise_pow_grad

* fix bugs

* fix bugs

* support pir prim backward ops

* refien

* fix bug

* migrate layer_norm custom vjp rules to pir

* fix bugs in ir_backward

* fix backward , scope, and concat_grad prim

* add layer_norm fwd decompose logic

* fix pow

* change _use_new_ir_api to in_pir_mode

* add _static_guard

* fix

* fix executor cuda700 error caused by full and full_like

* refine

* add vjp rules

* fix bugs

* add scope

* add test

* add add op prim rules

---------

Co-authored-by: YuanRisheng <yuanrisheng@baidu.com>
Co-authored-by: cyber-pioneer <chenzhuo@tju.edu.cn>
Co-authored-by: Charles-hit <wanghao107@baidu.com>
Co-authored-by: zhangbo9674 <zhangbo54@baidu.com>
  • Loading branch information
5 people authored and jiahy0825 committed Oct 16, 2023
1 parent f694bcc commit 516cfec
Show file tree
Hide file tree
Showing 21 changed files with 882 additions and 107 deletions.
18 changes: 16 additions & 2 deletions paddle/fluid/primitive/codegen/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,20 @@
]


PRIM_VJP = ['divide_grad', 'sum_grad'] # vjp list of primitive op
CUSTOM_VJP = ['gelu_grad'] # custom vjp list of composite op
PRIM_VJP = [
'divide_grad',
'sum_grad',
'cast_grad',
'add_grad',
'multiply_grad',
'elementwise_pow_grad',
'reshape_grad',
'split_grad',
'tanh_grad',
'transpose_grad',
'concat_grad',
] # vjp list of primitive op
CUSTOM_VJP = ['gelu_grad', 'layer_norm_grad'] # custom vjp list of composite op
VJP_COMPS = PRIM_VJP + CUSTOM_VJP

BACKENDS = [
Expand Down Expand Up @@ -149,6 +161,8 @@
'embedding_grad',
'sqrt',
'uniform',
'split',
'transpose',
]


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ paddle::Tensor* {{api.outputs[i].name}} = !stop_gradients[{{i}}][0] ? &vjp_res[{
{% else %}
std::vector<paddle::Tensor*> {{api.outputs[i].name}}(stop_gradients[{{i}}].size(), nullptr);
for (size_t i=0; i< stop_gradients[{{i}}].size(); i++ ) {
{{api.outputs[i].name}} = !stop_gradients[{{i}}][i] ? &vjp_res[{{i}}][i] : nullptr;
{{api.outputs[i].name}}[i] = !stop_gradients[{{i}}][i] ? &vjp_res[{{i}}][i] : nullptr;
}
{% endif %}
{% endfor %}
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/primitive/primitive.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,4 @@
- erf
- tanh
- full
- cast
Loading

0 comments on commit 516cfec

Please sign in to comment.