diff --git a/paddle/phi/core/kernel_factory.cc b/paddle/phi/core/kernel_factory.cc index d58decadfadca..2e85d521c516f 100644 --- a/paddle/phi/core/kernel_factory.cc +++ b/paddle/phi/core/kernel_factory.cc @@ -14,6 +14,10 @@ #include "paddle/phi/core/kernel_factory.h" +#include +#include +#include + #include "glog/logging.h" #include "paddle/phi/core/enforce.h" #include "paddle/utils/flags.h" @@ -33,6 +37,10 @@ PHI_DEFINE_EXPORTED_bool(use_stride_kernel, true, "Whether to use strdie kernel if op support stride."); +PHI_DEFINE_EXPORTED_string(stride_kernel_blacklist, + "", + "It controls the strided kernel subset do not use."); + PD_DECLARE_int32(low_precision_op_list); PD_DECLARE_bool(enable_api_kernel_fallback); PD_DECLARE_bool(run_kp_kernel); @@ -226,14 +234,26 @@ KernelResult KernelFactory::SelectKernelOrThrowError( phi::errors::NotFound("The kernel `%s` is not registered.", kernel_name)); if (FLAGS_use_stride_kernel && use_strided_kernel) { - auto stride_kernel_iter = iter->second.find( - {const_kernel_key.backend() == paddle::experimental::Backend::GPUDNN - ? paddle::experimental::Backend::GPU - : const_kernel_key.backend(), - phi::DataLayout::STRIDED, - const_kernel_key.dtype()}); - if (stride_kernel_iter != iter->second.end()) { - return {stride_kernel_iter->second, false, true}; + std::regex reg(","); + std::unordered_set elems{ + std::sregex_token_iterator(FLAGS_stride_kernel_blacklist.begin(), + FLAGS_stride_kernel_blacklist.end(), + reg, + -1), + std::sregex_token_iterator()}; + elems.erase(""); + + if (!elems.count(kernel_name)) { + auto stride_kernel_iter = iter->second.find( + {const_kernel_key.backend() == paddle::experimental::Backend::GPUDNN + ? paddle::experimental::Backend::GPU + : const_kernel_key.backend(), + phi::DataLayout::STRIDED, + const_kernel_key.dtype()}); + if (stride_kernel_iter != iter->second.end()) { + VLOG(1) << "use strided kernel, kernel_name = " << kernel_name; + return {stride_kernel_iter->second, false, true}; + } } }