Skip to content

Commit

Permalink
- Candidate fix to PaddlePaddle#31992
Browse files Browse the repository at this point in the history
  • Loading branch information
jczaja committed Apr 7, 2021
1 parent a17c369 commit ca0d6ac
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 20 deletions.
16 changes: 6 additions & 10 deletions paddle/fluid/platform/device_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,8 @@ class MKLDNNDeviceContextThreadLocals {
// MKL-DNN stream used for execution of primitives (per-thread)
mkldnn::engine cur_engine;
mkldnn::stream cur_stream;
std::string key_suffix; // Key identifying current Executor
bool key_attach_thread_id = true;

Body();
~Body();
Expand All @@ -612,6 +614,10 @@ class MKLDNNDeviceContextThreadLocals {
void log_lib_version(void);
const mkldnn::engine& get_engine(void);
mkldnn::stream& get_stream(void);
void set_key_suffix(const std::string& suffix) { key_suffix = suffix; }
const std::string& get_key_suffix(void) const { return key_suffix; }
void disable_tid_in_key(void) { key_attach_thread_id = false; }
bool is_tid_used_in_key(void) const { return key_attach_thread_id; }
};
MKLDNNDeviceContextThreadLocals() = default;
MKLDNNDeviceContextThreadLocals(const MKLDNNDeviceContextThreadLocals& c) =
Expand Down Expand Up @@ -655,14 +661,6 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
// Remove all entries from the blob map
void ResetBlobMap();

// Set a suffix to be added to key
void SetKeySuffix(const std::string& suffix) { key_suffix_ = suffix; }
const std::string& GetKeySuffix(void) const { return key_suffix_; }

// Disable adding thread ID to the key
void DisableThreadInfoInKey(void) { key_attach_thread_id_ = false; }
bool IsThreadIdUsedInKey(void) const { return key_attach_thread_id_; }

// Prevent next ResetBlobMap()
void BlockNextCacheClearing();

Expand All @@ -686,8 +684,6 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
std::shared_ptr<BlobMap> p_blobmap_;
std::shared_ptr<std::mutex> p_mutex_;
bool block_next_cache_clearing_ = false;
std::string key_suffix_; // Key identifying current Executor
bool key_attach_thread_id_ = true;
};
#endif

Expand Down
30 changes: 20 additions & 10 deletions paddle/fluid/platform/mkldnn_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -439,14 +439,23 @@ inline void AppendKey(std::string* key, const std::vector<T>& dims) {
inline void AttachPointerHashToMKLDNNKey(void* ptr,
const platform::Place& place) {
if (platform::is_cpu_place(place)) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
platform::MKLDNNDeviceContext* dev_ctx =
(platform::MKLDNNDeviceContext*)pool.Get(place);
dev_ctx->SetKeySuffix("E" +
std::to_string(reinterpret_cast<uintptr_t>(ptr)));
// When NaiveExecutor/Executor is used no info on thread id is needed in a
// key
dev_ctx->DisableThreadInfoInKey();
// Static vars will remember first executor and its thread
// so both of them need to be processed by the same thread within
// critical section
static std::mutex static_vars_barrier;
static_vars_barrier.lock();
static auto first_exec = ptr;
static auto first_thread = ThreadIDasStr();
static_vars_barrier.unlock();

if (first_exec != ptr) {
paddle::platform::MKLDNNDeviceContext::tls().set_key_suffix(
"E" + std::to_string(reinterpret_cast<uintptr_t>(ptr)));
}
// For first thread
if (first_thread == ThreadIDasStr()) {
paddle::platform::MKLDNNDeviceContext::tls().disable_tid_in_key();
}
}
}

Expand All @@ -457,13 +466,14 @@ inline std::string CreateKey(const platform::MKLDNNDeviceContext& dev_ctx,
key.reserve(64);
using expand_type = int[];
expand_type{0, (AppendKey(&key, std::forward<ArgTypes>(args)), 0)...};
key += dev_ctx.GetKeySuffix();
key += paddle::platform::MKLDNNDeviceContext::tls().get_key_suffix();
return key;
}

inline std::string ExtendKeyWithThreadInfoIfNeeded(
const platform::MKLDNNDeviceContext& dev_ctx, const std::string& key) {
return ((dev_ctx.IsThreadIdUsedInKey() == true) &&
return ((paddle::platform::MKLDNNDeviceContext::tls().is_tid_used_in_key() ==
true) &&
(platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id() ==
platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default))
? key + "-t:" + ThreadIDasStr()
Expand Down

0 comments on commit ca0d6ac

Please sign in to comment.