Skip to content

Commit

Permalink
Revert "[PHI] delete dense_tensor mem_desc_ (PaddlePaddle#59918)"
Browse files Browse the repository at this point in the history
This reverts commit 4b56993.
  • Loading branch information
chalsliu committed Dec 14, 2023
1 parent 5edf9b7 commit 4708689
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 24 deletions.
10 changes: 10 additions & 0 deletions paddle/phi/core/dense_tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ DenseTensor::DenseTensor(const DenseTensor& other) {
storage_properties_ =
std::move(CopyStorageProperties(other.storage_properties_));
inplace_version_counter_ = other.inplace_version_counter_;

#ifdef PADDLE_WITH_DNNL
mem_desc_ = other.mem_desc_;
#endif
}

DenseTensor& DenseTensor::operator=(const DenseTensor& other) {
Expand All @@ -70,6 +74,9 @@ DenseTensor& DenseTensor::operator=(const DenseTensor& other) {
storage_properties_ =
std::move(CopyStorageProperties(other.storage_properties_));
inplace_version_counter_ = other.inplace_version_counter_;
#ifdef PADDLE_WITH_DNNL
mem_desc_ = other.mem_desc_;
#endif
return *this;
}

Expand All @@ -78,6 +85,9 @@ DenseTensor& DenseTensor::operator=(DenseTensor&& other) noexcept {
std::swap(holder_, other.holder_);
storage_properties_ = std::move(other.storage_properties_);
std::swap(inplace_version_counter_, other.inplace_version_counter_);
#ifdef PADDLE_WITH_DNNL
mem_desc_ = other.mem_desc_;
#endif
return *this;
}

Expand Down
18 changes: 18 additions & 0 deletions paddle/phi/core/dense_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ limitations under the License. */
#include "paddle/phi/core/tensor_meta.h"
#include "paddle/utils/test_macros.h"

/* @jim19930609: Move to MKLDNN_Tensor in the future
*/
#ifdef PADDLE_WITH_DNNL
#include "dnnl.hpp" // NOLINT
#endif

namespace phi {

class DenseTensorUtils;
Expand Down Expand Up @@ -284,6 +290,18 @@ class TEST_API DenseTensor : public TensorBase,
std::shared_ptr<InplaceVersion> inplace_version_counter_ =
std::make_shared<InplaceVersion>();

/* @jim19930609: This is a hack
In general, it is badly designed to fuse MKLDNN-specific objects into a
generic Tensor.
We temporarily leave them here to unblock Tensor Unification progress.
In the final state, we should come up with a MKLDNN_Tensor and move the
following codes there.
*/
#ifdef PADDLE_WITH_DNNL
/// \brief memory descriptor of tensor which have layout set as kMKLDNN
dnnl::memory::desc mem_desc_;
#endif

#ifndef PADDLE_WITH_CUSTOM_KERNEL
#include "paddle/phi/core/dense_tensor.inl"
#endif
Expand Down
5 changes: 4 additions & 1 deletion paddle/phi/core/dense_tensor.inl
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,10 @@ following codes there.
public:
const dnnl::memory::desc& mem_desc() const;

void set_mem_desc(const dnnl::memory::desc& mem_desc);
inline void set_mem_desc(const dnnl::memory::desc& mem_desc) {
mem_desc_ = mem_desc;
meta_.layout = DataLayout::ONEDNN;
}

#endif

Expand Down
27 changes: 4 additions & 23 deletions paddle/phi/core/dense_tensor_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -377,29 +377,7 @@ std::vector<DenseTensor> DenseTensor::Chunk(int64_t chunks,
}

#ifdef PADDLE_WITH_DNNL
const dnnl::memory::desc& DenseTensor::mem_desc() const {
if (storage_properties_ == nullptr) {
std::unique_ptr<StorageProperties>* storage_properties_ptr =
const_cast<std::unique_ptr<StorageProperties>*>(&storage_properties_);
*storage_properties_ptr = std::make_unique<OneDNNStorageProperties>();
}
return this->storage_properties<OneDNNStorageProperties>().mem_desc;
}

void DenseTensor::set_mem_desc(const dnnl::memory::desc& mem_desc) {
if (storage_properties_ == nullptr) {
storage_properties_ = std::make_unique<OneDNNStorageProperties>();
}
if (OneDNNStorageProperties::classof(storage_properties_.get())) {
dynamic_cast<OneDNNStorageProperties*>(storage_properties_.get())
->mem_desc = mem_desc;
meta_.layout = DataLayout::ONEDNN;
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"The actual type of storage_properties is inconsistent with the type "
"of the template parameter passed in."));
}
}
const dnnl::memory::desc& DenseTensor::mem_desc() const { return mem_desc_; }
#endif

// NOTE: For historical reasons, this interface has a special behavior,
Expand All @@ -416,6 +394,9 @@ DenseTensor& DenseTensor::ShareDataWith(const DenseTensor& src) {
meta_.strides = src.meta_.strides;
storage_properties_ =
std::move(CopyStorageProperties(src.storage_properties_));
#ifdef PADDLE_WITH_DNNL
mem_desc_ = src.mem_desc_;
#endif
return *this;
}

Expand Down

0 comments on commit 4708689

Please sign in to comment.