diff --git a/src/cpu/aarch64/acl_post_ops.hpp b/src/cpu/aarch64/acl_post_ops.hpp index b49b97624e9..7966694f715 100644 --- a/src/cpu/aarch64/acl_post_ops.hpp +++ b/src/cpu/aarch64/acl_post_ops.hpp @@ -34,6 +34,12 @@ struct acl_post_ops_t { status_t init(engine_t *engine, post_ops_t &post_ops, const memory_desc_t &dst_md) { + // Disable ACL post ops when in f16 mode. This is because the oneDNN reference runs + // the post op in f32 and then casts down to f16 while ACL runs the post op in f16 + // leading to a loss of accuracy compared to ref. + ACL_CHECK_SUPPORT( + post_ops.len() >= 1 && dst_md.data_type == data_type::f16, + "post ops cannot be executed in fp16"); CHECK(post_ops.set_default_formats(&dst_md)); // Reset properties derived from post_ops @@ -128,6 +134,12 @@ struct acl_post_ops_t { const memory_desc_t &dst_md, arm_compute::ActivationLayerInfo &act_info_to_fuse) { + // Disable ACL post ops when in f16 mode. This is because the oneDNN reference runs + // the post op in f32 and then casts down to f16 while ACL runs the post op in f16 + // leading to a loss of accuracy compared to ref. + ACL_CHECK_SUPPORT( + base_post_ops.len() >= 1 && dst_md.data_type == data_type::f16, + "post ops cannot be executed in fp16"); CHECK(base_post_ops.set_default_formats(&dst_md)); // If the first entry is eltwise, we fuse it