Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PIR] Pir onednn ops yaml #61474

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,14 @@ static phi::Attribute ConvertPirAttribute2RuntimeAttribute(
}
}
return vec_res;
} else if (attr_type_name == "paddle::dialect::IntArrayAttribute") {
std::vector<int64_t> int_array =
attr.dyn_cast<paddle::dialect::IntArrayAttribute>().data().GetData();
return int_array;
} else if (attr_type_name == "paddle::dialect::DataTypeAttribute") {
phi::DataType dtype =
attr.dyn_cast<paddle::dialect::DataTypeAttribute>().data();
return dtype;
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"ConvertPirAttribute2RuntimeAttribute not support [%s] ",
Expand Down
109 changes: 50 additions & 59 deletions paddle/fluid/pir/dialect/operator/ir/ops_onednn_extra.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@

# - op : add_n

# - op : add_raw

- op : batch_norm
extra_args : bool fuse_with_relu=false
data_format_tensors : x
Expand All @@ -27,13 +25,15 @@
extra_args : bool fuse_with_relu=false
data_format_tensors : x, out_grad

# - op : bilinear_interp
- op : bilinear_interp

# - op : cast

# - op : clip
- op : clip
extra_args : str mkldnn_data_type="float32"

# - op : clip_grad
- op : clip_grad
extra_args : str mkldnn_data_type="float32"

# - op : concat

Expand All @@ -59,19 +59,17 @@

# - op : depthwise_conv2d_grad

# - op : divide

# - op : divide_grad
- op : divide

# - op : divide_raw
- op : divide_grad

# - op : elu
- op : elu

# - op : elu_grad
- op : elu_grad

# - op : exp
- op : exp

# - op : exp_grad
- op : exp_grad

# - op : expand

Expand All @@ -87,7 +85,7 @@

# - op : flatten2_grad

# - op : full
- op : full

- op : fused_conv2d
extra_args : float fuse_alpha = 0.0, float fuse_beta = 0.0, float scale_in=1.0, float scale_out=1.0, float scale_in_eltwise=1.0, float[] scale_weights={1.0f}
Expand Down Expand Up @@ -115,23 +113,25 @@

# - op : fusion_lstm

# - op : gaussian
- op : gaussian

# - op : gelu
- op : gelu
extra_args : str mkldnn_data_type="float32"

# - op : gelu_grad
- op : gelu_grad
extra_args : str mkldnn_data_type="float32"

# - op : hardswish
- op : hardswish

# - op : hardswish_grad
- op : hardswish_grad

# - op : layer_norm

# - op : leaky_relu
- op : leaky_relu

# - op : leaky_relu_grad
- op : leaky_relu_grad

# - op : log_softmax
- op : log_softmax

- op : lrn
extra_args : bool is_test=false
Expand All @@ -153,33 +153,25 @@

# - op : matmul_with_flatten_grad

# - op : max

# - op : max_raw

# - op : mean

# - op : mean_grad
- op : max

# - op : mean_raw
- op : mean

# - op : min
- op : mean_grad

# - op : min_raw
- op : min

# - op : mish
- op : mish

# - op : mish_grad
- op : mish_grad

# - op : multi_gru
- op : multi_gru

# - op : multiply
- op : multiply

# - op : multiply_grad
- op : multiply_grad

# - op : multiply_raw

# - op : nearest_interp
- op : nearest_interp

# - op : pad

Expand Down Expand Up @@ -215,21 +207,22 @@

# - op : reshape2_grad

# - op : round
- op : round

# - op : scale
- op : scale

# - op : sgd
- op : sgd

# - op : sgd_dense_param_sparse_grad

# - op : shape
# extra_args : str mkldnn_data_type="float32"

# - op : shuffle_channel
- op : shuffle_channel

# - op : sigmoid
- op : sigmoid

# - op : sigmoid_grad
- op : sigmoid_grad

# - op : slice

Expand All @@ -239,15 +232,15 @@

# - op : softmax_grad

# - op : softplus
- op : softplus

# - op : split

# - op : split_with_num

# - op : sqrt
- op : sqrt

# - op : sqrt_grad
- op : sqrt_grad

# - op : squeeze

Expand All @@ -257,25 +250,23 @@

# - op : stack

# - op : subtract

# - op : subtract_grad

# - op : subtract_raw
- op : subtract

# - op : sum
- op : subtract_grad

# - op : sum_grad
- op : sum
extra_args : str mkldnn_data_type="float32"

# - op : sum_raw
- op : sum_grad
extra_args : str mkldnn_data_type="float32"

# - op : swish

# - op : swish_grad

# - op : tanh
- op : tanh

# - op : tanh_grad
- op : tanh_grad

# - op : transpose

Expand Down
Loading