Skip to content

Commit

Permalink
[PIR] Pir onednn ops yaml (#61474)
Browse files Browse the repository at this point in the history
* support 44 onednn ops's yaml
  • Loading branch information
wanghuancoder authored Feb 5, 2024
1 parent aba5e38 commit a4c6d3d
Show file tree
Hide file tree
Showing 20 changed files with 646 additions and 205 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,14 @@ static phi::Attribute ConvertPirAttribute2RuntimeAttribute(
}
}
return vec_res;
} else if (attr_type_name == "paddle::dialect::IntArrayAttribute") {
std::vector<int64_t> int_array =
attr.dyn_cast<paddle::dialect::IntArrayAttribute>().data().GetData();
return int_array;
} else if (attr_type_name == "paddle::dialect::DataTypeAttribute") {
phi::DataType dtype =
attr.dyn_cast<paddle::dialect::DataTypeAttribute>().data();
return dtype;
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"ConvertPirAttribute2RuntimeAttribute not support [%s] ",
Expand Down
109 changes: 50 additions & 59 deletions paddle/fluid/pir/dialect/operator/ir/ops_onednn_extra.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@

# - op : add_n

# - op : add_raw

- op : batch_norm
extra_args : bool fuse_with_relu=false
data_format_tensors : x
Expand All @@ -27,13 +25,15 @@
extra_args : bool fuse_with_relu=false
data_format_tensors : x, out_grad

# - op : bilinear_interp
- op : bilinear_interp

# - op : cast

# - op : clip
- op : clip
extra_args : str mkldnn_data_type="float32"

# - op : clip_grad
- op : clip_grad
extra_args : str mkldnn_data_type="float32"

# - op : concat

Expand All @@ -59,19 +59,17 @@

# - op : depthwise_conv2d_grad

# - op : divide

# - op : divide_grad
- op : divide

# - op : divide_raw
- op : divide_grad

# - op : elu
- op : elu

# - op : elu_grad
- op : elu_grad

# - op : exp
- op : exp

# - op : exp_grad
- op : exp_grad

# - op : expand

Expand All @@ -87,7 +85,7 @@

# - op : flatten2_grad

# - op : full
- op : full

- op : fused_conv2d
extra_args : float fuse_alpha = 0.0, float fuse_beta = 0.0, float scale_in=1.0, float scale_out=1.0, float scale_in_eltwise=1.0, float[] scale_weights={1.0f}
Expand Down Expand Up @@ -115,23 +113,25 @@

# - op : fusion_lstm

# - op : gaussian
- op : gaussian

# - op : gelu
- op : gelu
extra_args : str mkldnn_data_type="float32"

# - op : gelu_grad
- op : gelu_grad
extra_args : str mkldnn_data_type="float32"

# - op : hardswish
- op : hardswish

# - op : hardswish_grad
- op : hardswish_grad

# - op : layer_norm

# - op : leaky_relu
- op : leaky_relu

# - op : leaky_relu_grad
- op : leaky_relu_grad

# - op : log_softmax
- op : log_softmax

- op : lrn
extra_args : bool is_test=false
Expand All @@ -153,33 +153,25 @@

# - op : matmul_with_flatten_grad

# - op : max

# - op : max_raw

# - op : mean

# - op : mean_grad
- op : max

# - op : mean_raw
- op : mean

# - op : min
- op : mean_grad

# - op : min_raw
- op : min

# - op : mish
- op : mish

# - op : mish_grad
- op : mish_grad

# - op : multi_gru
- op : multi_gru

# - op : multiply
- op : multiply

# - op : multiply_grad
- op : multiply_grad

# - op : multiply_raw

# - op : nearest_interp
- op : nearest_interp

# - op : pad

Expand Down Expand Up @@ -215,21 +207,22 @@

# - op : reshape2_grad

# - op : round
- op : round

# - op : scale
- op : scale

# - op : sgd
- op : sgd

# - op : sgd_dense_param_sparse_grad

# - op : shape
# extra_args : str mkldnn_data_type="float32"

# - op : shuffle_channel
- op : shuffle_channel

# - op : sigmoid
- op : sigmoid

# - op : sigmoid_grad
- op : sigmoid_grad

# - op : slice

Expand All @@ -239,15 +232,15 @@

# - op : softmax_grad

# - op : softplus
- op : softplus

# - op : split

# - op : split_with_num

# - op : sqrt
- op : sqrt

# - op : sqrt_grad
- op : sqrt_grad

# - op : squeeze

Expand All @@ -257,25 +250,23 @@

# - op : stack

# - op : subtract

# - op : subtract_grad

# - op : subtract_raw
- op : subtract

# - op : sum
- op : subtract_grad

# - op : sum_grad
- op : sum
extra_args : str mkldnn_data_type="float32"

# - op : sum_raw
- op : sum_grad
extra_args : str mkldnn_data_type="float32"

# - op : swish

# - op : swish_grad

# - op : tanh
- op : tanh

# - op : tanh_grad
- op : tanh_grad

# - op : transpose

Expand Down
Loading

0 comments on commit a4c6d3d

Please sign in to comment.