diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 411fe09c864aa..02ad22f780f8d 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -600,6 +600,8 @@ class MKLDNNDeviceContextThreadLocals { // MKL-DNN stream used for execution of primitives (per-thread) mkldnn::engine cur_engine; mkldnn::stream cur_stream; + std::string key_suffix; // Key identifying current Executor + bool key_attach_thread_id = true; Body(); ~Body(); @@ -612,6 +614,10 @@ class MKLDNNDeviceContextThreadLocals { void log_lib_version(void); const mkldnn::engine& get_engine(void); mkldnn::stream& get_stream(void); + void set_key_suffix(const std::string& suffix) { key_suffix = suffix; } + const std::string& get_key_suffix(void) const { return key_suffix; } + void disable_tid_in_key(void) { key_attach_thread_id = false; } + bool is_tid_used_in_key(void) const { return key_attach_thread_id; } }; MKLDNNDeviceContextThreadLocals() = default; MKLDNNDeviceContextThreadLocals(const MKLDNNDeviceContextThreadLocals& c) = @@ -655,14 +661,6 @@ class MKLDNNDeviceContext : public CPUDeviceContext { // Remove all entries from the blob map void ResetBlobMap(); - // Set a suffix to be added to key - void SetKeySuffix(const std::string& suffix) { key_suffix_ = suffix; } - const std::string& GetKeySuffix(void) const { return key_suffix_; } - - // Disable adding thread ID to the key - void DisableThreadInfoInKey(void) { key_attach_thread_id_ = false; } - bool IsThreadIdUsedInKey(void) const { return key_attach_thread_id_; } - // Prevent next ResetBlobMap() void BlockNextCacheClearing(); @@ -686,8 +684,6 @@ class MKLDNNDeviceContext : public CPUDeviceContext { std::shared_ptr p_blobmap_; std::shared_ptr p_mutex_; bool block_next_cache_clearing_ = false; - std::string key_suffix_; // Key identifying current Executor - bool key_attach_thread_id_ = true; }; #endif diff --git a/paddle/fluid/platform/mkldnn_helper.h b/paddle/fluid/platform/mkldnn_helper.h index 20e6dfe1c3916..35776b9f1e6b8 100644 --- a/paddle/fluid/platform/mkldnn_helper.h +++ b/paddle/fluid/platform/mkldnn_helper.h @@ -439,14 +439,23 @@ inline void AppendKey(std::string* key, const std::vector& dims) { inline void AttachPointerHashToMKLDNNKey(void* ptr, const platform::Place& place) { if (platform::is_cpu_place(place)) { - platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); - platform::MKLDNNDeviceContext* dev_ctx = - (platform::MKLDNNDeviceContext*)pool.Get(place); - dev_ctx->SetKeySuffix("E" + - std::to_string(reinterpret_cast(ptr))); - // When NaiveExecutor/Executor is used no info on thread id is needed in a - // key - dev_ctx->DisableThreadInfoInKey(); + // Static vars will remember first executor and its thread + // so both of them need to be processed by the same thread within + // critical section + static std::mutex static_vars_barrier; + static_vars_barrier.lock(); + static auto first_exec = ptr; + static auto first_thread = ThreadIDasStr(); + static_vars_barrier.unlock(); + + if (first_exec != ptr) { + paddle::platform::MKLDNNDeviceContext::tls().set_key_suffix( + "E" + std::to_string(reinterpret_cast(ptr))); + } + // For first thread + if (first_thread == ThreadIDasStr()) { + paddle::platform::MKLDNNDeviceContext::tls().disable_tid_in_key(); + } } } @@ -457,13 +466,14 @@ inline std::string CreateKey(const platform::MKLDNNDeviceContext& dev_ctx, key.reserve(64); using expand_type = int[]; expand_type{0, (AppendKey(&key, std::forward(args)), 0)...}; - key += dev_ctx.GetKeySuffix(); + key += paddle::platform::MKLDNNDeviceContext::tls().get_key_suffix(); return key; } inline std::string ExtendKeyWithThreadInfoIfNeeded( const platform::MKLDNNDeviceContext& dev_ctx, const std::string& key) { - return ((dev_ctx.IsThreadIdUsedInKey() == true) && + return ((paddle::platform::MKLDNNDeviceContext::tls().is_tid_used_in_key() == + true) && (platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id() == platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default)) ? key + "-t:" + ThreadIDasStr()