Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[PRIM][PIR]Migrate prim rules (PaddlePaddle#57554)
* fix bugs of generating Op::Build when Op has optional tensor * add default constructor for IrMetaTensor * fix bugs * polish guard * pir support prim gelu and rsqrt * support prim bwd ops * migrate vjp rules of cast,add,multiply,elementwise_pow * add cast as primitive op * fix bugs in elementwise_pow_grad * add test for cast_grad * add test for elementwise_add_grad * add test for elementwise_mul_grad * add test for elementwise_pow_grad * fix bugs * fix bugs * support pir prim backward ops * refien * fix bug * migrate layer_norm custom vjp rules to pir * fix bugs in ir_backward * fix backward , scope, and concat_grad prim * add layer_norm fwd decompose logic * fix pow * change _use_new_ir_api to in_pir_mode * add _static_guard * fix * fix executor cuda700 error caused by full and full_like * refine * add vjp rules * fix bugs * add scope * add test * add add op prim rules --------- Co-authored-by: YuanRisheng <yuanrisheng@baidu.com> Co-authored-by: cyber-pioneer <chenzhuo@tju.edu.cn> Co-authored-by: Charles-hit <wanghao107@baidu.com> Co-authored-by: zhangbo9674 <zhangbo54@baidu.com>
- Loading branch information