[IR] Auto gen fused op (PaddlePaddle#56585)
* add code

* fix bug

* fix bug
zhangbo9674 authored and BeingGod committed Sep 9, 2023
Parent: 281957c. Commit: 2e0a6da.
Showing 5 changed files with 19 additions and 11 deletions.
paddle/fluid/ir/dialect/op_generator/op_gen.py (1 addition, 0 deletions)

```diff
@@ -112,6 +112,7 @@ class {op_name} : public ir::Op<{op_name}{interfaces}{traits}> {{
 #include "paddle/phi/infermeta/unary.h"
 #include "paddle/phi/infermeta/ternary.h"
 #include "paddle/phi/infermeta/backward.h"
+#include "paddle/phi/infermeta/fusion.h"
 #include "paddle/phi/api/lib/utils/allocator.h"
 #include "paddle/fluid/primitive/rule/vjp/vjp.h"
 {def_primitive}
```
paddle/fluid/ir/dialect/paddle_dialect/ir/CMakeLists.txt (7 additions, 1 deletion)

```diff
@@ -17,11 +17,17 @@ set(op_backward_yaml_file1
 set(op_backward_yaml_file2
     ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/legacy_backward_ops.parsed.yaml
 )
+set(fused_op_forward_yaml_file
+    ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/fused_ops.parsed.yaml
+)
+set(fused_op_backward_yaml_file
+    ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/fused_backward.parsed.yaml
+)
 set(op_yaml_file3
     ${PADDLE_SOURCE_DIR}/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_op.yaml)
 
 set(op_yaml_files
-    ${op_forward_yaml_file1},${op_forward_yaml_file2},${op_backward_yaml_file1},${op_backward_yaml_file2},${op_yaml_file3}
+    ${op_forward_yaml_file1},${op_forward_yaml_file2},${op_backward_yaml_file1},${op_backward_yaml_file2},${fused_op_forward_yaml_file},${fused_op_backward_yaml_file},${op_yaml_file3}
 )
 set(op_namespace paddle,dialect)
 set(dialect_name pd)
```
paddle/phi/api/yaml/fused_ops.yaml (1 addition, 0 deletions)

```diff
@@ -58,6 +58,7 @@
   output: Tensor(out), Tensor(seq_lod), Tensor(max_seq_len)
   infer_meta :
     func: EmbeddingWithEltwiseAddXPUInferMeta
+    param : [ids, tables, mask]
   kernel:
     func: embedding_with_eltwise_add_xpu
     data_type: tables
```
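For context, a `param` entry in this YAML schema selects which op inputs the generated code forwards to the InferMeta function, in the listed order, ahead of the output slots. Below is a minimal C++ sketch of the declaration this entry would bind to; the container types are assumptions for illustration, and the authoritative signature lives in paddle/phi/infermeta/fusion.h.

```cpp
#include <vector>

// Stand-in for phi::MetaTensor (illustrative only).
struct MetaTensor {};

// Assumed declaration matching `param : [ids, tables, mask]`: the three
// listed inputs come first, in order, followed by one out-parameter per
// declared output (out, seq_lod, max_seq_len). Container types here are
// guesses; check paddle/phi/infermeta/fusion.h for the real signature.
void EmbeddingWithEltwiseAddXPUInferMeta(
    const std::vector<const MetaTensor*>& ids,
    const std::vector<const MetaTensor*>& tables,
    const MetaTensor& mask,
    MetaTensor* out,
    MetaTensor* seq_lod,
    MetaTensor* max_seq_len);
```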
paddle/phi/infermeta/fusion.cc (5 additions, 5 deletions)

```diff
@@ -466,11 +466,11 @@ void FusedMultiTransformerXpuInferMeta(
     const std::vector<const MetaTensor*>& ffn2_bias,
     const std::vector<const MetaTensor*>& cache_kv,
     const std::vector<const MetaTensor*>& pre_caches,
-    const std::vector<const MetaTensor*>& rotary_pos_emb,
-    const std::vector<const MetaTensor*>& time_step,
-    const std::vector<const MetaTensor*>& seq_lengths,
-    const std::vector<const MetaTensor*>& src_mask,
-    const std::vector<const MetaTensor*>& gather_index,
+    const MetaTensor& rotary_pos_emb,
+    const MetaTensor& time_step,
+    const MetaTensor& seq_lengths,
+    const MetaTensor& src_mask,
+    const MetaTensor& gather_index,
     bool pre_layer_norm,
     int rotary_emb_dims,
     float epsilon,
```
paddle/phi/infermeta/fusion.h (5 additions, 5 deletions)

```diff
@@ -143,11 +143,11 @@ void FusedMultiTransformerXpuInferMeta(
     const std::vector<const MetaTensor*>& ffn2_bias,
     const std::vector<const MetaTensor*>& cache_kv,
     const std::vector<const MetaTensor*>& pre_caches,
-    const std::vector<const MetaTensor*>& rotary_pos_emb,
-    const std::vector<const MetaTensor*>& time_step,
-    const std::vector<const MetaTensor*>& seq_lengths,
-    const std::vector<const MetaTensor*>& src_mask,
-    const std::vector<const MetaTensor*>& gather_index,
+    const MetaTensor& rotary_pos_emb,
+    const MetaTensor& time_step,
+    const MetaTensor& seq_lengths,
+    const MetaTensor& src_mask,
+    const MetaTensor& gather_index,
     bool pre_layer_norm,
     int rotary_emb_dims,
     float epsilon,
```
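Both declarations narrow five optional single-tensor inputs from `std::vector<const MetaTensor*>` to `const MetaTensor&`. The self-contained toy program below illustrates the calling-pattern difference with stand-in types (not Paddle's real API): an absent optional input goes from an empty vector to a default-constructed, uninitialized tensor handle.

```cpp
#include <iostream>
#include <vector>

// Stand-in for phi::MetaTensor (illustrative only): a tensor handle whose
// "initialized" flag marks whether an optional input was provided.
struct MetaTensor {
  bool initialized = false;
};

// Old style: a single optional input passed as a vector of pointers,
// empty when absent.
void InferMetaOld(const std::vector<const MetaTensor*>& time_step) {
  std::cout << "old style, provided = " << !time_step.empty() << "\n";
}

// New style (the shape used in this commit): one MetaTensor reference per
// single-tensor input; an uninitialized tensor means "not provided".
void InferMetaNew(const MetaTensor& time_step) {
  std::cout << "new style, provided = " << time_step.initialized << "\n";
}

int main() {
  MetaTensor absent;         // default-constructed: not provided
  MetaTensor present{true};  // pretend this input was supplied

  InferMetaOld({});          // prints: old style, provided = 0
  InferMetaNew(absent);      // prints: new style, provided = 0
  InferMetaOld({&present});  // prints: old style, provided = 1
  InferMetaNew(present);     // prints: new style, provided = 1
  return 0;
}
```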
