Skip to content

Commit

Permalink
Add mechanism for blocking oneDNN cache clearing (#26502)
Browse files Browse the repository at this point in the history
* Add mechanism for blocking oneDNN cache clearing

* Review changes and Add thread guards
  • Loading branch information
grygielski authored Aug 21, 2020
1 parent 7d3e46e commit f390902
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 4 deletions.
8 changes: 8 additions & 0 deletions paddle/fluid/operators/run_program_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ limitations under the License. */
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/var_type_traits.h"
#include "paddle/fluid/framework/variable.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif

DECLARE_bool(use_mkldnn);

namespace paddle {
namespace operators {
Expand Down Expand Up @@ -262,6 +267,9 @@ class RunProgramOpKernel : public framework::OpKernel<T> {
}
VLOG(2) << "The number of sub scopes after forward: "
<< out_scope_vec->front()->kids().size();
#ifdef PADDLE_WITH_MKLDNN
if (FLAGS_use_mkldnn) DontClearMKLDNNCache(ctx.GetPlace());
#endif
}
};

Expand Down
18 changes: 15 additions & 3 deletions paddle/fluid/platform/device_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -464,9 +464,21 @@ MKLDNNDeviceContextThreadLocals::Body::get_cur_paddle_data_layout(void) {
return cur_paddle_data_layout;
}

void MKLDNNDeviceContext::ResetBlobMap() const {
VLOG(3) << "Clearing DNNL cache.";
p_blobmap_->clear();
void MKLDNNDeviceContext::ResetBlobMap() {
std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
if (!block_next_cache_clearing_) {
VLOG(3) << "Clearing DNNL cache.";
p_blobmap_->clear();
} else {
VLOG(3) << "Prevented Clearing DNNL cache.";
block_next_cache_clearing_ = false;
}
}

void MKLDNNDeviceContext::BlockNextCacheClearing() {
std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
VLOG(3) << "Next DNNL cache clearing has been blocked.";
block_next_cache_clearing_ = true;
}

size_t MKLDNNDeviceContext::GetShapeBlobSize() const {
Expand Down
6 changes: 5 additions & 1 deletion paddle/fluid/platform/device_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,10 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
const mkldnn::engine& GetEngine() const { return engine_; }

// Remove all entries from the blob map
void ResetBlobMap() const;
void ResetBlobMap();

// Prevent next ResetBlobMap()
void BlockNextCacheClearing();

// Get the ShapeBlob size in cur_mkldnn_session_id.
size_t GetShapeBlobSize() const;
Expand All @@ -539,6 +542,7 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
mkldnn::engine engine_;
std::shared_ptr<BlobMap> p_blobmap_;
std::shared_ptr<std::mutex> p_mutex_;
bool block_next_cache_clearing_ = false;
};
#endif

Expand Down
10 changes: 10 additions & 0 deletions paddle/fluid/platform/mkldnn_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,16 @@ inline void ClearMKLDNNCache(const platform::Place& place) {
}
}

inline void DontClearMKLDNNCache(const platform::Place& place) {
// Clear mkl-dnn cache,
if (platform::is_cpu_place(place)) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
platform::MKLDNNDeviceContext* dev_ctx =
(platform::MKLDNNDeviceContext*)pool.Get(place);
dev_ctx->BlockNextCacheClearing();
}
}

template <typename Type>
mkldnn::memory::data_type MKLDNNGetDataType() {
return mkldnn::memory::data_type::undef;
Expand Down

0 comments on commit f390902

Please sign in to comment.