
Add cast api && Change copy related api to copy_to && add more test #4

Merged
merged 107 commits into from
Feb 8, 2021
Changes from all commits
107 commits
136e4de
fix compile error
JiabinYang Jan 27, 2021
2e4fc0d
wrap framework tensor with LoDTensor
JiabinYang Jan 29, 2021
618d917
fix compile error
JiabinYang Jan 29, 2021
d49f476
fix compile error
JiabinYang Jan 29, 2021
d14cd51
fix compile error
JiabinYang Jan 29, 2021
4e719fc
fix compile error
JiabinYang Jan 29, 2021
87059c5
fix compile error
JiabinYang Jan 29, 2021
2ef89a5
add CustomTensor default constructor
JiabinYang Feb 1, 2021
f217ccb
add size() for CustomTensor
JiabinYang Feb 1, 2021
7f9f1cd
make size const for CustomTensor
JiabinYang Feb 1, 2021
8102863
refactor place related api to circle the concept
JiabinYang Feb 1, 2021
74bfc55
merge new op_functor
JiabinYang Feb 1, 2021
c5b3b5c
fix compile error
JiabinYang Feb 1, 2021
c67e36f
fix compile error
JiabinYang Feb 1, 2021
bb4c295
fix compile error
JiabinYang Feb 1, 2021
d416fdb
fix compile error
JiabinYang Feb 1, 2021
8cc60ec
fix compile error
JiabinYang Feb 1, 2021
1dccc2d
fix compile error
JiabinYang Feb 1, 2021
4b304f2
fix compile error
JiabinYang Feb 1, 2021
dcda6cd
fix compile error
JiabinYang Feb 1, 2021
bec954f
fix compile error
JiabinYang Feb 1, 2021
2c5edac
fix compile error
JiabinYang Feb 1, 2021
6990b99
fix compile error
JiabinYang Feb 1, 2021
55b6a13
fix compile error
JiabinYang Feb 1, 2021
abaa67e
fix compile error
JiabinYang Feb 2, 2021
f8b23d4
fix compile error
JiabinYang Feb 2, 2021
ce4ecd0
fix compile error
JiabinYang Feb 2, 2021
0bb004c
fix compile error
JiabinYang Feb 2, 2021
33ad438
fix compile error
JiabinYang Feb 2, 2021
2e433cc
fix compile error
JiabinYang Feb 2, 2021
4e26c71
merge final op_function
JiabinYang Feb 2, 2021
6c1752e
make place const
JiabinYang Feb 2, 2021
a4d190b
make Tensor copy
JiabinYang Feb 2, 2021
b9dde0a
debug CustomTensor core
JiabinYang Feb 3, 2021
219746a
debug CustomTensor core
JiabinYang Feb 3, 2021
bedd624
debug CustomTensor core
JiabinYang Feb 3, 2021
a148ea2
debug CustomTensor core
JiabinYang Feb 3, 2021
1757e3a
debug CustomTensor core
JiabinYang Feb 3, 2021
1815a0f
debug CustomTensor core
JiabinYang Feb 3, 2021
b1e94cd
debug CustomTensor core
JiabinYang Feb 3, 2021
dbd0e17
debug CustomTensor core
JiabinYang Feb 3, 2021
984d11f
debug CustomTensor core
JiabinYang Feb 3, 2021
1d2eae7
debug CustomTensor core
JiabinYang Feb 3, 2021
eda48e8
debug CustomTensor core
JiabinYang Feb 3, 2021
284125c
debug CustomTensor core
JiabinYang Feb 3, 2021
0851daa
debug CustomTensor core
JiabinYang Feb 3, 2021
ea98ccb
debug CustomTensor core
JiabinYang Feb 3, 2021
e04bd30
remove additional head of framework
JiabinYang Feb 3, 2021
1c0cd18
use back to shared ptr for custom tensor
JiabinYang Feb 3, 2021
aa09b08
use back to shared ptr for custom tensor
JiabinYang Feb 3, 2021
330b650
use back to shared ptr for custom tensor
JiabinYang Feb 3, 2021
743a91f
use back to shared ptr for custom tensor
JiabinYang Feb 3, 2021
9b8917b
use back to shared ptr for custom tensor
JiabinYang Feb 3, 2021
627fa2e
use back to shared ptr for custom tensor
JiabinYang Feb 3, 2021
7ecffc0
add gpu test
JiabinYang Feb 4, 2021
a52b8ee
merge latest cwh code in
JiabinYang Feb 4, 2021
687c9ef
merge latest cwh code in
JiabinYang Feb 4, 2021
a9cd76a
adjust ut code of custom op
JiabinYang Feb 4, 2021
2afe58a
adjust ut code of custom op
JiabinYang Feb 4, 2021
5693375
adjust ut code of custom op
JiabinYang Feb 4, 2021
0332e29
adjust ut code of custom op
JiabinYang Feb 5, 2021
a9a7550
adjust ut code of custom op
JiabinYang Feb 5, 2021
9aa0d69
hid share data from and to
JiabinYang Feb 5, 2021
6bbea36
rename CustomTensor to Tensor
JiabinYang Feb 5, 2021
0e66ee9
merge cwh code
JiabinYang Feb 5, 2021
3fb3f0a
support multi dtype
JiabinYang Feb 7, 2021
dc18813
remove lod, make reshape lowercase, add copy test and refactor copy api
JiabinYang Feb 7, 2021
a83c469
remove lod, make reshape lowercase, add copy test and refactor copy api
JiabinYang Feb 7, 2021
df6ba59
remove lod, make reshape lowercase, add copy test and refactor copy api
JiabinYang Feb 7, 2021
5272c85
remove lod, make reshape lowercase, add copy test and refactor copy api
JiabinYang Feb 7, 2021
cae22da
merge cwh code and add more dtype && change PaddleDtype to DataType
JiabinYang Feb 7, 2021
19a8ff7
fix copy to error
JiabinYang Feb 7, 2021
1b6ecf6
merge cwh code
JiabinYang Feb 7, 2021
07d3795
add more test
JiabinYang Feb 7, 2021
49ed21c
add more test
JiabinYang Feb 7, 2021
9288fff
add more test
JiabinYang Feb 7, 2021
c775ea7
add more test
JiabinYang Feb 7, 2021
db42afc
add more test
JiabinYang Feb 7, 2021
2243035
add more test
JiabinYang Feb 7, 2021
c634ab0
add more test
JiabinYang Feb 7, 2021
46f8758
add more test
JiabinYang Feb 7, 2021
d912a99
add more test
JiabinYang Feb 7, 2021
4735e8d
add more test
JiabinYang Feb 7, 2021
4d78356
add more test
JiabinYang Feb 7, 2021
d886e9b
add more test
JiabinYang Feb 7, 2021
43ed2a7
add more test
JiabinYang Feb 7, 2021
0e7f286
add more test
JiabinYang Feb 7, 2021
d12969c
add more test
JiabinYang Feb 7, 2021
34af5ab
add more test
JiabinYang Feb 8, 2021
cb63fcb
add type cast
JiabinYang Feb 8, 2021
896e31d
merge cwh code
JiabinYang Feb 8, 2021
325c783
add cast and make copy to api
JiabinYang Feb 8, 2021
f3c897d
add cast and make copy to api
JiabinYang Feb 8, 2021
d506876
add cast and make copy to api
JiabinYang Feb 8, 2021
7afdaf9
add cast and make copy to api
JiabinYang Feb 8, 2021
47d851d
merge cwh code
JiabinYang Feb 8, 2021
853b25e
merge cwh code
JiabinYang Feb 8, 2021
0838234
merge cwh code
JiabinYang Feb 8, 2021
b90eed8
merge cwh code
JiabinYang Feb 8, 2021
9752379
merge cwh code
JiabinYang Feb 8, 2021
4cc3a60
merge cwh code
JiabinYang Feb 8, 2021
c84fe6f
add more error log
JiabinYang Feb 8, 2021
84bd7b2
add more error log
JiabinYang Feb 8, 2021
855fa0a
polish code
JiabinYang Feb 8, 2021
3a65ab6
used for test
JiabinYang Feb 8, 2021
efeb766
remove test comment
JiabinYang Feb 8, 2021
2a43205
remove test comment
JiabinYang Feb 8, 2021
1 change: 1 addition & 0 deletions paddle/fluid/extension/include/dtype.h
@@ -32,6 +32,7 @@ enum DataType {
INT16,
UINT8,
INT8,
BOOL,
// TODO(JiabinYang) support more data types if needed.
};

13 changes: 5 additions & 8 deletions paddle/fluid/extension/include/tensor.h
@@ -59,17 +59,11 @@ class Tensor {

/// \brief Copy the host memory to tensor data.
/// It's usually used to set the input tensor data.
/// \param data The pointer of the data, from which
/// \param PlaceType of target place, from which
/// the tensor will copy.
template <typename T>
Tensor copy_to_gpu();

/// \brief Copy the tensor data to the host memory.
/// It's usually used to get the output tensor data.
/// \param[out] data The tensor will copy the data to
/// the address.
template <typename T>
Tensor copy_to_cpu();
Tensor copy_to(const PlaceType& place);

/// \brief Return the shape of the Tensor.
std::vector<int> shape() const;
@@ -89,6 +83,9 @@
/// \return Place.
const PlaceType& place() const;

/// \brief Cast datatype from one to another
Tensor cast(const DataType& target_type);

private:
friend class framework::CustomTensorUtils;
mutable std::shared_ptr<void> tensor_;
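For orientation, here is a minimal usage sketch of the interface declared above: the former copy_to_gpu()/copy_to_cpu() pair collapses into a single templated copy_to that takes the destination place, and cast converts the element type. The helper below is hypothetical, assumes paddle::PlaceType::kCPU and paddle::DataType::FLOAT64 are available as elsewhere in this PR, and is not part of the diff:

#include "paddle/fluid/extension/include/tensor.h"

// Hypothetical helper: bring a float tensor to the CPU and widen it to double.
paddle::Tensor ToCpuDouble(const paddle::Tensor &x_float) {
  // The destination is now an explicit PlaceType argument instead of being
  // baked into the method name.
  paddle::Tensor on_cpu = x_float.copy_to<float>(paddle::PlaceType::kCPU);
  // cast() performs the dtype conversion introduced by this PR.
  return on_cpu.cast(paddle::DataType::FLOAT64);
}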
243 changes: 166 additions & 77 deletions paddle/fluid/extension/src/tensor.cc
@@ -18,10 +18,76 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/gpu_info.h"
#include "paddle/fluid/platform/transform.h"

namespace paddle {

template <typename InType, typename OutType>
struct CastDataTypeFunctor {
HOSTDEVICE inline OutType operator()(InType in) const {
return static_cast<OutType>(in);
}
};

template <typename InType>
struct CastDataType {
CastDataType(const framework::Tensor &in, framework::Tensor *out,
const platform::DeviceContext *ctx)
: in_(in), out_(out), ctx_(ctx) {}
const framework::Tensor in_;
framework::Tensor *out_;
const platform::DeviceContext *ctx_;

template <typename OutType>
void apply() {
auto *in_begin = in_.data<InType>();
auto *in_end = in_begin + in_.numel();
auto *out_begin = out_->mutable_data<OutType>(in_.place());

if (platform::is_cpu_place(in_.place())) {
platform::Transform<platform::CPUDeviceContext> trans;
auto *context = static_cast<const platform::CPUDeviceContext *>(ctx_);
trans(*context, in_begin, in_end, out_begin,
CastDataTypeFunctor<InType, OutType>());
#ifdef __NVCC__
} else if (platform::is_gpu_place(in_.place())) {
platform::Transform<platform::CUDADeviceContext> trans;
auto *context = static_cast<const platform::CUDADeviceContext *>(ctx_);
trans(*context, in_begin, in_end, out_begin,
CastDataTypeFunctor<InType, OutType>());
context->Wait();
#endif
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Place type is not supported when casting data type."));
}
}
};
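// --- Illustrative sketch, not part of this diff -----------------------------
// CastDataType is the inner half of a double dispatch: Tensor::cast (added
// further down) switches on the source dtype to pick CastDataType<InType>,
// and framework::VisitDataType then instantiates apply<OutType>() for the
// destination dtype. The same pattern in a stripped-down, Paddle-free form,
// with hypothetical names:
enum class SketchDType { kF32, kI64 };

template <typename In>
struct SketchCastVisitor {
  const In *src;
  void *dst;
  int64_t n;
  template <typename Out>
  void apply() {  // chosen once the destination type is known
    auto *out = static_cast<Out *>(dst);
    for (int64_t i = 0; i < n; ++i) out[i] = static_cast<Out>(src[i]);
  }
};

// Stand-in for framework::VisitDataType: map the runtime destination dtype
// to the compile-time Out parameter of apply().
template <typename Visitor>
void SketchVisitDType(SketchDType dst, Visitor vis) {
  switch (dst) {
    case SketchDType::kF32: vis.template apply<float>(); break;
    case SketchDType::kI64: vis.template apply<int64_t>(); break;
  }
}
// -----------------------------------------------------------------------------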
template <typename T>
void GpuCopy(T *src, T *dst, PlaceType src_plc, PlaceType dst_plc,
int64_t ele_size) {
#ifdef PADDLE_WITH_CUDA
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
int device_num = paddle::platform::GetCurrentDeviceId();
platform::CUDAPlace gpu_place(device_num);
auto *dev_ctx =
static_cast<const platform::CUDADeviceContext *>(pool.Get(gpu_place));
if ((src_plc == PlaceType::kGPU) && (dst_plc == PlaceType::kCPU)) {
memory::Copy(platform::CPUPlace(), static_cast<void *>(dst), gpu_place, src,
ele_size, dev_ctx->stream());
} else if ((src_plc == PlaceType::kGPU) && (dst_plc == PlaceType::kGPU)) {
memory::Copy(gpu_place, static_cast<void *>(dst), gpu_place, src, ele_size,
dev_ctx->stream());
} else if ((src_plc == PlaceType::kCPU) && (dst_plc == PlaceType::kGPU)) {
memory::Copy(gpu_place, static_cast<void *>(dst), platform::CPUPlace(), src,
ele_size, dev_ctx->stream());
} else {
PADDLE_THROW("Only GPU related Copy can reach this func.");
}
cudaStreamSynchronize(dev_ctx->stream());
#endif
}

#define GET_CASTED_TENSOR \
if (!tensor_) { \
tensor_ = std::make_shared<framework::LoDTensor>(); \
@@ -55,12 +121,12 @@ T *Tensor::mutable_data() {
case static_cast<int>(PlaceType::kCPU): {
return tensor->mutable_data<T>(platform::CPUPlace());
}
case static_cast<int>(PlaceType::kGPU): {
#ifdef PADDLE_WITH_CUDA
case static_cast<int>(PlaceType::kGPU): {
int device_num = platform::GetCurrentDeviceId();
return tensor->mutable_data<T>(platform::CUDAPlace(device_num));
#endif
}
#endif
default:
PADDLE_THROW(platform::errors::Unavailable(
"CustomOp unsupported place: %d", static_cast<int>(place_)));
@@ -99,99 +165,62 @@ DataType Tensor::type() const {
return DataType::COMPLEX64;
} else if (type == framework::proto::VarType::COMPLEX128) {
return DataType::COMPLEX128;
} else if (type == framework::proto::VarType::BOOL) {
return DataType::BOOL;
}
return DataType::FLOAT32;
}

template <typename T>
Tensor Tensor::copy_to_gpu() {
#ifdef PADDLE_WITH_CUDA
Tensor Tensor::copy_to(const PlaceType &target_place) {
GET_CASTED_TENSOR;
PADDLE_ENFORCE_GE(tensor->numel(), 0,
platform::errors::PreconditionNotMet(
"You should call Tensor::Reshape(const "
"std::vector<int> &shape)"
"function before copying data from cpu."));
size_t ele_size = tensor->numel() * sizeof(T);
Tensor target = Tensor(PlaceType::kGPU);
auto *p_src_data = tensor->data<T>();
auto src_place = place();
Tensor target = Tensor(target_place);
target.reshape(shape());
auto *p_target_data = target.template mutable_data<T>();
auto p_src_data = tensor->data<T>();

platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
int device_num = platform::GetCurrentDeviceId();
platform::CUDAPlace gpu_place(device_num);
auto *dev_ctx =
static_cast<const platform::CUDADeviceContext *>(pool.Get(gpu_place));
if (platform::is_cpu_place(tensor->place())) {
memory::Copy(gpu_place, static_cast<void *>(p_target_data),
platform::CPUPlace(), p_src_data, ele_size, dev_ctx->stream());
} else {
memory::Copy(gpu_place, static_cast<void *>(p_target_data), gpu_place,
p_src_data, ele_size, dev_ctx->stream());
}
cudaStreamSynchronize(dev_ctx->stream());
return target;
#else
PADDLE_THROW(
platform::errors::Unavailable("PaddlePaddle is not compiled with CUDA"));
#endif
return Tensor(PlaceType::kGPU);
}

template <typename T>
Tensor Tensor::copy_to_cpu() {
GET_CASTED_TENSOR;
auto ele_num = tensor->numel();
auto *t_data = tensor->data<T>();
auto t_place = tensor->place();
Tensor target = Tensor(PlaceType::kCPU);
target.reshape(shape());
auto *p_target_data = target.template mutable_data<T>();
if (platform::is_cpu_place(t_place)) {
std::memcpy(static_cast<void *>(p_target_data), t_data,
ele_num * sizeof(T));
if ((src_place == PlaceType::kCPU) && (target_place == PlaceType::kCPU)) {
std::memcpy(static_cast<void *>(p_target_data), p_src_data, ele_size);
} else if ((src_place == PlaceType::kGPU) &&
(target_place == PlaceType::kCPU)) {
GpuCopy<T>(p_src_data, p_target_data, src_place, target_place, ele_size);
} else if ((src_place == PlaceType::kCPU) &&
(target_place == PlaceType::kGPU)) {
GpuCopy<T>(p_src_data, p_target_data, src_place, target_place, ele_size);
} else if ((src_place == PlaceType::kGPU) &&
(target_place == PlaceType::kGPU)) {
GpuCopy<T>(p_src_data, p_target_data, src_place, target_place, ele_size);
} else {
#ifdef PADDLE_WITH_CUDA
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto gpu_place = BOOST_GET_CONST(platform::CUDAPlace, t_place);
auto *dev_ctx =
static_cast<const platform::CUDADeviceContext *>(pool.Get(gpu_place));
memory::Copy(platform::CPUPlace(), static_cast<void *>(p_target_data),
gpu_place, t_data, ele_num * sizeof(T), dev_ctx->stream());

cudaStreamSynchronize(dev_ctx->stream());
#else
PADDLE_THROW(platform::errors::Unavailable(
"PaddlePaddle is not compiled with CUDA."));
#endif
"Not supported place transform of place: %d to place: %d",
static_cast<int>(src_place), static_cast<int>(target_place)));
}
return target;
}

template Tensor Tensor::copy_to_gpu<paddle::platform::float16>();
template Tensor Tensor::copy_to_gpu<paddle::platform::bfloat16>();
template Tensor Tensor::copy_to_gpu<paddle::platform::complex64>();
template Tensor Tensor::copy_to_gpu<paddle::platform::complex128>();
template Tensor Tensor::copy_to_gpu<float>();
template Tensor Tensor::copy_to_gpu<double>();
template Tensor Tensor::copy_to_gpu<int64_t>();
template Tensor Tensor::copy_to_gpu<int32_t>();
template Tensor Tensor::copy_to_gpu<uint8_t>();
template Tensor Tensor::copy_to_gpu<int8_t>();
template Tensor Tensor::copy_to_gpu<int16_t>();

template Tensor Tensor::copy_to_cpu<paddle::platform::float16>();
template Tensor Tensor::copy_to_cpu<paddle::platform::bfloat16>();
template Tensor Tensor::copy_to_cpu<paddle::platform::complex64>();
template Tensor Tensor::copy_to_cpu<paddle::platform::complex128>();
template Tensor Tensor::copy_to_cpu<float>();
template Tensor Tensor::copy_to_cpu<double>();
template Tensor Tensor::copy_to_cpu<int64_t>();
template Tensor Tensor::copy_to_cpu<int32_t>();
template Tensor Tensor::copy_to_cpu<uint8_t>();
template Tensor Tensor::copy_to_cpu<int8_t>();
template Tensor Tensor::copy_to_cpu<int16_t>();
template Tensor Tensor::copy_to<paddle::platform::float16>(
const PlaceType &target_place);
template Tensor Tensor::copy_to<paddle::platform::bfloat16>(
const PlaceType &target_place);
template Tensor Tensor::copy_to<paddle::platform::complex64>(
const PlaceType &target_place);
template Tensor Tensor::copy_to<paddle::platform::complex128>(
const PlaceType &target_place);
template Tensor Tensor::copy_to<float>(const PlaceType &target_place);
template Tensor Tensor::copy_to<double>(const PlaceType &target_place);
template Tensor Tensor::copy_to<int64_t>(const PlaceType &target_place);
template Tensor Tensor::copy_to<int32_t>(const PlaceType &target_place);
template Tensor Tensor::copy_to<uint8_t>(const PlaceType &target_place);
template Tensor Tensor::copy_to<int8_t>(const PlaceType &target_place);
template Tensor Tensor::copy_to<int16_t>(const PlaceType &target_place);
template Tensor Tensor::copy_to<bool>(const PlaceType &target_place);

template float *Tensor::data<float>() const;
template double *Tensor::data<double>() const;
@@ -208,6 +237,7 @@ Tensor::data<paddle::platform::complex128>() const;
template paddle::platform::complex64 *
Tensor::data<paddle::platform::complex64>() const;
template int16_t *Tensor::data<int16_t>() const;
template bool *Tensor::data<bool>() const;

template float *Tensor::mutable_data<float>();
template double *Tensor::mutable_data<double>();
@@ -224,6 +254,7 @@ Tensor::mutable_data<paddle::platform::complex128>();
template paddle::platform::complex64 *
Tensor::mutable_data<paddle::platform::complex64>();
template int16_t *Tensor::mutable_data<int16_t>();
template bool *Tensor::mutable_data<bool>();

template float *Tensor::mutable_data<float>(const PlaceType &place);
template double *Tensor::mutable_data<double>(const PlaceType &place);
@@ -240,6 +271,7 @@ Tensor::mutable_data<paddle::platform::complex128>(const PlaceType &place);
template paddle::platform::complex64 *
Tensor::mutable_data<paddle::platform::complex64>(const PlaceType &place);
template int16_t *Tensor::mutable_data<int16_t>(const PlaceType &place);
template bool *Tensor::mutable_data<bool>(const PlaceType &place);

std::vector<int> Tensor::shape() const {
GET_CASTED_TENSOR
@@ -261,6 +293,62 @@ const PlaceType &Tensor::place() const {
return place_;
}

Tensor Tensor::cast(const DataType &target_type) {
GET_CASTED_TENSOR;
Tensor rlt = Tensor(place());
rlt.reshape(this->shape());
auto rlt_tensor_ = static_cast<framework::LoDTensor *>(rlt.tensor_.get());
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto ctx = pool.Get(tensor->place());
auto src_type = tensor->type();
auto dst_type =
framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(target_type);
switch (src_type) {
case framework::proto::VarType::FP16:
framework::VisitDataType(
dst_type, CastDataType<platform::float16>(*tensor, rlt_tensor_, ctx));
break;
case framework::proto::VarType::BF16:
framework::VisitDataType(dst_type, CastDataType<platform::bfloat16>(
*tensor, rlt_tensor_, ctx));
break;
case framework::proto::VarType::FP32:
framework::VisitDataType(dst_type,
CastDataType<float>(*tensor, rlt_tensor_, ctx));
break;
case framework::proto::VarType::FP64:
framework::VisitDataType(dst_type,
CastDataType<double>(*tensor, rlt_tensor_, ctx));
break;
case framework::proto::VarType::INT32:
framework::VisitDataType(dst_type,
CastDataType<int>(*tensor, rlt_tensor_, ctx));
break;
case framework::proto::VarType::INT64:
framework::VisitDataType(
dst_type, CastDataType<int64_t>(*tensor, rlt_tensor_, ctx));
break;
case framework::proto::VarType::BOOL:
framework::VisitDataType(dst_type,
CastDataType<bool>(*tensor, rlt_tensor_, ctx));
break;
case framework::proto::VarType::INT16:
framework::VisitDataType(
dst_type, CastDataType<int16_t>(*tensor, rlt_tensor_, ctx));
break;
case framework::proto::VarType::UINT8:
framework::VisitDataType(
dst_type, CastDataType<u_int8_t>(*tensor, rlt_tensor_, ctx));
break;
// TODO(JiabinYang): Support Complex later
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Data type (%s) is not supported when casting data type.",
framework::DataTypeToString(src_type)));
}
return rlt;
}

int64_t Tensor::size() const {
GET_CASTED_TENSOR;
return tensor->numel();
@@ -273,12 +361,13 @@ void CustomTensorUtils::ShareDataTo(const paddle::Tensor &src, void *dst) {
*static_cast<framework::LoDTensor *>(src.tensor_.get()));
}

void CustomTensorUtils::ShareDataFrom(void *src, const paddle::Tensor &dst) {
void CustomTensorUtils::ShareDataFrom(const void *src,
const paddle::Tensor &dst) {
if (!dst.tensor_) {
dst.tensor_ = std::make_shared<framework::LoDTensor>();
}
auto *tensor = static_cast<framework::LoDTensor *>(dst.tensor_.get());
tensor->ShareDataWith(*static_cast<framework::LoDTensor *>(src));
tensor->ShareDataWith(*static_cast<const framework::LoDTensor *>(src));
}

} // namespace framework
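The new unit tests in this PR exercise these paths end to end. As a rough illustration only — not code from this PR — a CPU to GPU to CPU round trip through the new copy_to might look like the following, assuming a CUDA-enabled build:

#include <cassert>
#include <cstdint>
#include "paddle/fluid/extension/include/tensor.h"

// Hypothetical round-trip check for copy_to; values and shape are made up.
void CheckCopyRoundTrip() {
  paddle::Tensor t(paddle::PlaceType::kCPU);
  t.reshape({4, 2});  // reshape() must be called before copying, per the
                      // precondition enforced in copy_to above.
  auto *src = t.mutable_data<float>();
  for (int64_t i = 0; i < t.size(); ++i) src[i] = static_cast<float>(i);

  auto on_gpu = t.copy_to<float>(paddle::PlaceType::kGPU);
  auto back = on_gpu.copy_to<float>(paddle::PlaceType::kCPU);

  auto *dst = back.data<float>();
  for (int64_t i = 0; i < back.size(); ++i) {
    assert(dst[i] == static_cast<float>(i));
  }
}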