diff --git a/src/gpu/generic/sycl/ref_softmax.hpp b/src/gpu/generic/sycl/ref_softmax.hpp index 5bf7506442d..43f30e54a35 100644 --- a/src/gpu/generic/sycl/ref_softmax.hpp +++ b/src/gpu/generic/sycl/ref_softmax.hpp @@ -111,7 +111,7 @@ struct ref_sycl_softmax_bwd_t : public gpu::generic::sycl::primitive_t { && dst_md()->data_type == diff_dst_md()->data_type && attr()->has_default_values() && set_default_formats() == status::success - && check_formats(src_md(), dst_md()) + && check_formats(diff_src_md(), diff_dst_md()) && md_dims_in_range(diff_dst_md()); if (!ok) return status::unimplemented;