Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[cherry-pick to 2.1][Second fix to #31992] #32664

Merged
merged 1 commit into from
Apr 29, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions paddle/fluid/framework/executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ Executor::~Executor() {
#ifdef PADDLE_WITH_MKLDNN
// Clear mkl-dnn cache,
// this is needed to have mkl-dnn unit tests working
ClearMKLDNNCache(place_);
ClearMKLDNNCache(place_, this);
#endif
}

Expand Down Expand Up @@ -169,6 +169,9 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
bool force_disable_gc, bool keep_kid_scopes) {
platform::RecordBlock b(block_id);
if (FLAGS_use_mkldnn) EnableMKLDNN(pdesc);
#ifdef PADDLE_WITH_MKLDNN
platform::AttachPointerHashToMKLDNNKey(this, place_);
#endif
auto ctx = Prepare(pdesc, block_id, skip_ref_cnt_vars, force_disable_gc);
RunPreparedContext(ctx.get(), scope, create_local_scope, create_vars,
keep_kid_scopes);
Expand Down Expand Up @@ -294,6 +297,9 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
const std::string& fetch_holder_name) {
platform::RecordBlock b(kProgramId);
if (FLAGS_use_mkldnn) EnableMKLDNN(program);
#ifdef PADDLE_WITH_MKLDNN
platform::AttachPointerHashToMKLDNNKey(this, place_);
#endif
bool has_feed_ops =
has_feed_operators(program.Block(0), *feed_targets, feed_holder_name);
bool has_fetch_ops =
Expand Down Expand Up @@ -576,7 +582,6 @@ void Executor::EnableMKLDNN(const ProgramDesc& program) {
}
}
}
platform::AttachPointerHashToMKLDNNKey(this, place_);
#else
LOG(WARNING)
<< "'MKLDNN' is not supported, Please re-compile with WITH_MKLDNN option";
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/framework/naive_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ NaiveExecutor::~NaiveExecutor() {
#ifdef PADDLE_WITH_MKLDNN
// Clear mkl-dnn cache,
// this is needed to have mkl-dnn unit tests working
ClearMKLDNNCache(place_);
ClearMKLDNNCache(place_, this);
#endif
}

Expand Down
3 changes: 2 additions & 1 deletion paddle/fluid/inference/api/mkldnn_quantizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,8 @@ void AnalysisPredictor::MkldnnQuantizer::ClearDeviceContext() const {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
platform::MKLDNNDeviceContext* dev_ctx =
(platform::MKLDNNDeviceContext*)pool.Get(predictor_.place_);
dev_ctx->ResetBlobMap();
dev_ctx->ResetBlobMap(
paddle::platform::MKLDNNDeviceContext::tls().get_curr_exec());
}

void AnalysisPredictor::MkldnnQuantizer::PrepareArgument() const {
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/operators/mkldnn/test_mkldnn_caching.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class CacheTester {
platform::CPUPlace place;
onednn_dev_ctx_ =
dynamic_cast<platform::MKLDNNDeviceContext *>(pool.Get(place));
onednn_dev_ctx_->ResetBlobMap();
onednn_dev_ctx_->ResetBlobMap(nullptr);
}

bool Analyze(unsigned short int num_entries) {
Expand Down
30 changes: 26 additions & 4 deletions paddle/fluid/platform/device_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -537,6 +537,7 @@ Place CUDAPinnedDeviceContext::GetPlace() const { return place_; }
MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
: CPUDeviceContext(place), p_blobmap_() {
p_blobmap_.reset(new BlobMap());
p_exec_items_.reset(new ExecMap());
p_mutex_.reset(new std::mutex());
}

Expand All @@ -560,7 +561,7 @@ MKLDNNDeviceContextThreadLocals::Body::~Body() {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
platform::MKLDNNDeviceContext* dev_ctx =
(platform::MKLDNNDeviceContext*)pool.Get(cpu_place);
dev_ctx->ResetBlobMap();
dev_ctx->ResetBlobMap(exec_ptr_);
}

void MKLDNNDeviceContextThreadLocals::Body::set_cur_mkldnn_session_id(
Expand Down Expand Up @@ -607,17 +608,34 @@ mkldnn::stream& MKLDNNDeviceContextThreadLocals::Body::get_stream(void) {
return cur_stream;
}

void MKLDNNDeviceContext::ResetBlobMap() {
void MKLDNNDeviceContext::ResetBlobMap(void* ptr) {
std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
if (!block_next_cache_clearing_) {
VLOG(3) << "Clearing DNNL cache.";
p_blobmap_->clear();
// If no specific executor pointer then clear
// everything. For executor pointer then clear only
// objects allocated when using given executor
if (ptr == nullptr) {
p_blobmap_->clear();
} else {
for (auto& v : (*p_exec_items_)[ptr]) {
(v.first)->erase(v.second);
}
p_exec_items_->erase(ptr);
}
} else {
VLOG(3) << "Prevented Clearing DNNL cache.";
block_next_cache_clearing_ = false;
}
}

void MKLDNNDeviceContext::LinkEntryWithExecutor(BlobPtr_t<KeyBlob> pblob,
KeyBlob::iterator it) const {
// Take current executor addess from TLS
// and for this executor's items add the one defined with arguments
(*p_exec_items_)[tls().get_curr_exec()].push_back(std::make_pair(pblob, it));
}

void MKLDNNDeviceContext::BlockNextCacheClearing() {
std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
VLOG(3) << "Next DNNL cache clearing has been blocked.";
Expand Down Expand Up @@ -682,7 +700,11 @@ void MKLDNNDeviceContext::SetBlob(const std::string& name,
// Find Blob via name
auto blob_it = pBlob->find(name);
if (blob_it == pBlob->end()) {
(*pBlob)[name] = data;
auto el =
pBlob->insert(std::make_pair(name, data)); // (*pBlob)[name] = data;
// Register new element in per executor map
// to have easily erased when executor terminated
LinkEntryWithExecutor(pBlob, el.first);
} else {
blob_it->second = data; // set data to existing blob
}
Expand Down
14 changes: 13 additions & 1 deletion paddle/fluid/platform/device_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -673,6 +673,7 @@ class MKLDNNDeviceContextThreadLocals {
mkldnn::stream cur_stream;
std::string key_suffix; // Key identifying current Executor
bool key_attach_thread_id = true;
void* exec_ptr_ = nullptr;

Body();
~Body();
Expand All @@ -689,6 +690,8 @@ class MKLDNNDeviceContextThreadLocals {
const std::string& get_key_suffix(void) const { return key_suffix; }
void disable_tid_in_key(void) { key_attach_thread_id = false; }
bool is_tid_used_in_key(void) const { return key_attach_thread_id; }
void set_curr_exec(void* exec_ptr) { exec_ptr_ = exec_ptr; }
void* get_curr_exec(void) const { return exec_ptr_; }
};
MKLDNNDeviceContextThreadLocals() = default;
MKLDNNDeviceContextThreadLocals(const MKLDNNDeviceContextThreadLocals& c) =
Expand Down Expand Up @@ -724,13 +727,19 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
using ShapeBlob = umap_key_string_t<KeyBlob>;
using BlobMap = umap_value_smart_t<int, ShapeBlob>;

using ExecMap = std::unordered_map<
void*, std::vector<std::pair<BlobPtr_t<KeyBlob>, KeyBlob::iterator>>>;

explicit MKLDNNDeviceContext(CPUPlace place);

/* \brief Get the active engine */
const mkldnn::engine& GetEngine() const { return tls().get_engine(); }

// Register object to currently used executor's map
void LinkEntryWithExecutor(BlobPtr_t<KeyBlob>, KeyBlob::iterator) const;

// Remove all entries from the blob map
void ResetBlobMap();
void ResetBlobMap(void* ptr);

// Prevent next ResetBlobMap()
void BlockNextCacheClearing();
Expand All @@ -753,6 +762,9 @@ class MKLDNNDeviceContext : public CPUDeviceContext {

private:
std::shared_ptr<BlobMap> p_blobmap_;
// Map key is pointer of executor and value is a data(iterator in map) needed
// to erase
std::shared_ptr<ExecMap> p_exec_items_;
std::shared_ptr<std::mutex> p_mutex_;
bool block_next_cache_clearing_ = false;
};
Expand Down
8 changes: 6 additions & 2 deletions paddle/fluid/platform/mkldnn_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,13 +135,14 @@ inline mkldnn::memory::desc MKLDNNMemDesc(const std::vector<int64_t>& dims,
return mkldnn::memory::desc({dims}, data_type, format);
}

inline void ClearMKLDNNCache(const platform::Place& place) {
inline void ClearMKLDNNCache(const platform::Place& place,
void* ptr = nullptr) {
// Clear mkl-dnn cache,
if (platform::is_cpu_place(place)) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
platform::MKLDNNDeviceContext* dev_ctx =
(platform::MKLDNNDeviceContext*)pool.Get(place);
dev_ctx->ResetBlobMap();
dev_ctx->ResetBlobMap(ptr);
platform::MKLDNNDeviceContext::tls().set_cur_paddle_data_layout(
paddle::framework::DataLayout::kNCHW);
}
Expand Down Expand Up @@ -452,6 +453,9 @@ inline void AttachPointerHashToMKLDNNKey(void* ptr,
paddle::platform::MKLDNNDeviceContext::tls().set_key_suffix(
"E" + std::to_string(reinterpret_cast<uintptr_t>(ptr)));
}
// Let's register adress of current executor
paddle::platform::MKLDNNDeviceContext::tls().set_curr_exec(ptr);

// For first thread
if (first_thread == ThreadIDasStr()) {
paddle::platform::MKLDNNDeviceContext::tls().disable_tid_in_key();
Expand Down