
refactor copy to api && change Reshape to lowercase && support more dtype && add more test #3

Merged
89 commits merged on Feb 8, 2021
Commits:
136e4de
fix compile error
JiabinYang Jan 27, 2021
2e4fc0d
wrap framework tensor with LoDTensor
JiabinYang Jan 29, 2021
618d917
fix compile error
JiabinYang Jan 29, 2021
d49f476
fix compile error
JiabinYang Jan 29, 2021
d14cd51
fix compile error
JiabinYang Jan 29, 2021
4e719fc
fix compile error
JiabinYang Jan 29, 2021
87059c5
fix compile error
JiabinYang Jan 29, 2021
2ef89a5
add CustomTensor default constructor
JiabinYang Feb 1, 2021
f217ccb
add size() for CustomTensor
JiabinYang Feb 1, 2021
7f9f1cd
make size const for CustomTensor
JiabinYang Feb 1, 2021
8102863
refactor place related api to circle the concept
JiabinYang Feb 1, 2021
74bfc55
merge new op_functor
JiabinYang Feb 1, 2021
c5b3b5c
fix compile error
JiabinYang Feb 1, 2021
c67e36f
fix compile error
JiabinYang Feb 1, 2021
bb4c295
fix compile error
JiabinYang Feb 1, 2021
d416fdb
fix compile error
JiabinYang Feb 1, 2021
8cc60ec
fix compile error
JiabinYang Feb 1, 2021
1dccc2d
fix compile error
JiabinYang Feb 1, 2021
4b304f2
fix compile error
JiabinYang Feb 1, 2021
dcda6cd
fix compile error
JiabinYang Feb 1, 2021
bec954f
fix compile error
JiabinYang Feb 1, 2021
2c5edac
fix compile error
JiabinYang Feb 1, 2021
6990b99
fix compile error
JiabinYang Feb 1, 2021
55b6a13
fix compile error
JiabinYang Feb 1, 2021
abaa67e
fix compile error
JiabinYang Feb 2, 2021
f8b23d4
fix compile error
JiabinYang Feb 2, 2021
ce4ecd0
fix compile error
JiabinYang Feb 2, 2021
0bb004c
fix compile error
JiabinYang Feb 2, 2021
33ad438
fix compile error
JiabinYang Feb 2, 2021
2e433cc
fix compile error
JiabinYang Feb 2, 2021
4e26c71
merge final op_function
JiabinYang Feb 2, 2021
6c1752e
make place const
JiabinYang Feb 2, 2021
a4d190b
make Tensor copy
JiabinYang Feb 2, 2021
b9dde0a
debug CustomTensor core
JiabinYang Feb 3, 2021
219746a
debug CustomTensor core
JiabinYang Feb 3, 2021
bedd624
debug CustomTensor core
JiabinYang Feb 3, 2021
a148ea2
debug CustomTensor core
JiabinYang Feb 3, 2021
1757e3a
debug CustomTensor core
JiabinYang Feb 3, 2021
1815a0f
debug CustomTensor core
JiabinYang Feb 3, 2021
b1e94cd
debug CustomTensor core
JiabinYang Feb 3, 2021
dbd0e17
debug CustomTensor core
JiabinYang Feb 3, 2021
984d11f
debug CustomTensor core
JiabinYang Feb 3, 2021
1d2eae7
debug CustomTensor core
JiabinYang Feb 3, 2021
eda48e8
debug CustomTensor core
JiabinYang Feb 3, 2021
284125c
debug CustomTensor core
JiabinYang Feb 3, 2021
0851daa
debug CustomTensor core
JiabinYang Feb 3, 2021
ea98ccb
debug CustomTensor core
JiabinYang Feb 3, 2021
e04bd30
remove additional head of framework
JiabinYang Feb 3, 2021
1c0cd18
use back to shared ptr for custom tensor
JiabinYang Feb 3, 2021
aa09b08
use back to shared ptr for custom tensor
JiabinYang Feb 3, 2021
330b650
use back to shared ptr for custom tensor
JiabinYang Feb 3, 2021
743a91f
use back to shared ptr for custom tensor
JiabinYang Feb 3, 2021
9b8917b
use back to shared ptr for custom tensor
JiabinYang Feb 3, 2021
627fa2e
use back to shared ptr for custom tensor
JiabinYang Feb 3, 2021
7ecffc0
add gpu test
JiabinYang Feb 4, 2021
a52b8ee
merge latest cwh code in
JiabinYang Feb 4, 2021
687c9ef
merge latest cwh code in
JiabinYang Feb 4, 2021
a9cd76a
adjust ut code of custom op
JiabinYang Feb 4, 2021
2afe58a
adjust ut code of custom op
JiabinYang Feb 4, 2021
5693375
adjust ut code of custom op
JiabinYang Feb 4, 2021
0332e29
adjust ut code of custom op
JiabinYang Feb 5, 2021
a9a7550
adjust ut code of custom op
JiabinYang Feb 5, 2021
9aa0d69
hid share data from and to
JiabinYang Feb 5, 2021
6bbea36
rename CustomTensor to Tensor
JiabinYang Feb 5, 2021
0e66ee9
merge cwh code
JiabinYang Feb 5, 2021
3fb3f0a
support multi dtype
JiabinYang Feb 7, 2021
dc18813
remove lod, make reshape lowercase, add copy test and refactor copy api
JiabinYang Feb 7, 2021
a83c469
remove lod, make reshape lowercase, add copy test and refactor copy api
JiabinYang Feb 7, 2021
df6ba59
remove lod, make reshape lowercase, add copy test and refactor copy api
JiabinYang Feb 7, 2021
5272c85
remove lod, make reshape lowercase, add copy test and refactor copy api
JiabinYang Feb 7, 2021
cae22da
merge cwh code and add more dtype && change PaddleDtype to DataType
JiabinYang Feb 7, 2021
19a8ff7
fix copy to error
JiabinYang Feb 7, 2021
1b6ecf6
merge cwh code
JiabinYang Feb 7, 2021
07d3795
add more test
JiabinYang Feb 7, 2021
49ed21c
add more test
JiabinYang Feb 7, 2021
9288fff
add more test
JiabinYang Feb 7, 2021
c775ea7
add more test
JiabinYang Feb 7, 2021
db42afc
add more test
JiabinYang Feb 7, 2021
2243035
add more test
JiabinYang Feb 7, 2021
c634ab0
add more test
JiabinYang Feb 7, 2021
46f8758
add more test
JiabinYang Feb 7, 2021
d912a99
add more test
JiabinYang Feb 7, 2021
4735e8d
add more test
JiabinYang Feb 7, 2021
4d78356
add more test
JiabinYang Feb 7, 2021
d886e9b
add more test
JiabinYang Feb 7, 2021
43ed2a7
add more test
JiabinYang Feb 7, 2021
0e7f286
add more test
JiabinYang Feb 7, 2021
d12969c
add more test
JiabinYang Feb 7, 2021
34af5ab
add more test
JiabinYang Feb 8, 2021
11 changes: 10 additions & 1 deletion paddle/fluid/extension/include/dtype.h
@@ -13,17 +13,26 @@ See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {

enum DataType {
FLOAT32,
FLOAT64,
BFLOAT16,
COMPLEX128,
COMPLEX64,
FLOAT16,
INT64,
INT32,
INT16,
UINT8,
INT8,
// TODO(yangjiabin): Add other dtype support in next PR
// TODO(Superjomn) support more data types if needed.
};

} // namespace paddle
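The expanded enum is consumed by `Tensor::type()` in tensor.cc below, which maps the framework's `proto::VarType` values onto these `DataType` values, falling back to `FLOAT32`. A minimal self-contained sketch of that mapping (the enums here are simplified stand-ins, not Paddle's actual definitions):

```cpp
#include <cassert>

// Simplified stand-ins for framework::proto::VarType and paddle::DataType;
// the names mirror the PR, but these enums are illustrative, not Paddle's.
enum class VarType { FP32, FP64, BF16, FP16, INT64, INT32, INT16, INT8, UINT8 };
enum class DataType { FLOAT32, FLOAT64, BFLOAT16, FLOAT16, INT64, INT32, INT16, INT8, UINT8 };

// Mirrors the if-else chain in Tensor::type(): every known VarType gets a
// DataType, and anything unrecognized falls back to FLOAT32.
DataType ToDataType(VarType t) {
  switch (t) {
    case VarType::FP64:  return DataType::FLOAT64;
    case VarType::BF16:  return DataType::BFLOAT16;
    case VarType::FP16:  return DataType::FLOAT16;
    case VarType::INT64: return DataType::INT64;
    case VarType::INT32: return DataType::INT32;
    case VarType::INT16: return DataType::INT16;
    case VarType::INT8:  return DataType::INT8;
    case VarType::UINT8: return DataType::UINT8;
    default:             return DataType::FLOAT32;
  }
}
```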
15 changes: 3 additions & 12 deletions paddle/fluid/extension/include/tensor.h
@@ -16,7 +16,6 @@ limitations under the License. */

#include <memory>
#include <vector>

#include "paddle/fluid/extension/include/dtype.h"
#include "paddle/fluid/extension/include/place.h"

@@ -31,7 +30,7 @@ class Tensor {
/// Generally it's only used for the input tensor.
/// Reshape must be called before calling mutable_data() or copy_from_cpu()
/// \param shape The shape to set.
void Reshape(const std::vector<int>& shape);
void reshape(const std::vector<int>& shape);

/// \brief Get the memory pointer in CPU or GPU with specific data type.
/// Please Reshape the tensor first before call this.
@@ -57,25 +56,17 @@
/// It's usually used to set the input tensor data.
/// \param data The pointer of the data, from which the tensor will copy.
Owner: note also need to adjust
Collaborator (author): indent?
template <typename T>
void copy_from_cpu(const T* data);
Tensor copy_to_gpu();

/// \brief Copy the tensor data to the host memory.
/// It's usually used to get the output tensor data.
/// \param[out] data The tensor will copy the data to the address.
template <typename T>
void copy_to_cpu(T* data);
Tensor copy_to_cpu();

/// \brief Return the shape of the Tensor.
std::vector<int> shape() const;

/// \brief Set lod info of the tensor.
/// More about LOD can be seen here:
/// https://www.paddlepaddle.org.cn/documentation/docs/zh/beginners_guide/basic_concept/lod_tensor.html#lodtensor
/// \param x the lod info.
void SetLoD(const std::vector<std::vector<size_t>>& x);
/// \brief Return the lod info of the tensor.
std::vector<std::vector<size_t>> lod() const;

/// \brief Return the data type of the tensor.
/// It's usually used to get the output tensor data type.
/// \return The data type of the tensor.
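The net effect of the header change: `Reshape` becomes lowercase `reshape`, and the pointer-based `copy_from_cpu(const T*)` / `copy_to_cpu(T*)` pair is replaced by value-returning `copy_to_gpu()` / `copy_to_cpu()`. A self-contained, float-only mock of the new copy-by-value call shape (an illustration with no Paddle dependency, not Paddle's implementation):

```cpp
#include <cassert>
#include <cstring>
#include <vector>

// Minimal float-only mock: reshape() is lowercase, and copy_to_cpu() returns
// a fresh Tensor rather than filling a caller-supplied pointer.
class Tensor {
 public:
  void reshape(const std::vector<int>& shape) {
    size_t n = 1;
    for (int d : shape) n *= static_cast<size_t>(d);
    shape_ = shape;
    data_.assign(n, 0.0f);  // allocate and zero, like a fresh mutable_data()
  }
  float* mutable_data() { return data_.data(); }
  std::vector<int> shape() const { return shape_; }

  // New-style copy: allocate a target with the same shape, then copy,
  // mirroring the CPU branch of Tensor::copy_to_cpu() in tensor.cc.
  Tensor copy_to_cpu() const {
    Tensor target;
    target.reshape(shape_);
    std::memcpy(target.data_.data(), data_.data(),
                data_.size() * sizeof(float));
    return target;
  }

 private:
  std::vector<int> shape_;
  std::vector<float> data_;
};
```

The value-returning form means the caller no longer has to pre-allocate a correctly sized buffer, which was an easy way to corrupt memory with the old pointer-based API.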
154 changes: 96 additions & 58 deletions paddle/fluid/extension/src/tensor.cc
@@ -12,10 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/extension/include/tensor.h"

#include <utility>

#include "paddle/fluid/extension/include/all.h"
@chenwhql (Owner) commented, Feb 8, 2021.
Collaborator (author): done
#include "paddle/fluid/framework/custom_tensor_utils.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/memory/memcpy.h"
@@ -30,7 +28,7 @@ namespace paddle {
} \
auto *tensor = static_cast<framework::LoDTensor *>(tensor_.get());

void Tensor::Reshape(const std::vector<int> &shape) {
void Tensor::reshape(const std::vector<int> &shape) {
GET_CASTED_TENSOR
tensor->Resize(framework::make_ddim(shape));
}
@@ -85,129 +83,169 @@ DataType Tensor::type() const {
return DataType::INT64;
} else if (type == framework::proto::VarType::INT32) {
return DataType::INT32;
} else if (type == framework::proto::VarType::INT16) {
return DataType::INT16;
} else if (type == framework::proto::VarType::INT8) {
return DataType::INT8;
} else if (type == framework::proto::VarType::UINT8) {
return DataType::UINT8;
} else if (type == framework::proto::VarType::FP64) {
return DataType::FLOAT64;
} else if (type == framework::proto::VarType::BF16) {
return DataType::BFLOAT16;
} else if (type == framework::proto::VarType::FP16) {
return DataType::FLOAT16;
} else if (type == framework::proto::VarType::COMPLEX64) {
return DataType::COMPLEX64;
} else if (type == framework::proto::VarType::COMPLEX128) {
return DataType::COMPLEX128;
}
return DataType::FLOAT32;
}

template <typename T>
void Tensor::copy_from_cpu(const T *data) {
Tensor Tensor::copy_to_gpu() {
#ifdef PADDLE_WITH_CUDA
GET_CASTED_TENSOR;
PADDLE_ENFORCE_GE(tensor->numel(), 0,
platform::errors::PreconditionNotMet(
"You should call Tensor::Reshape(const "
"std::vector<int> &shape)"
"function before copying data from cpu."));
size_t ele_size = tensor->numel() * sizeof(T);

if (place_ == PlaceType::kCPU) {
auto *t_data = tensor->mutable_data<T>(platform::CPUPlace());
std::memcpy(static_cast<void *>(t_data), data, ele_size);
Tensor target = Tensor(PlaceType::kGPU);
target.reshape(shape());
auto *p_target_data = target.template mutable_data<T>();
auto p_src_data = tensor->data<T>();

platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
int device_num = platform::GetCurrentDeviceId();
platform::CUDAPlace gpu_place(device_num);
auto *dev_ctx =
static_cast<const platform::CUDADeviceContext *>(pool.Get(gpu_place));
if (platform::is_cpu_place(tensor->place())) {
memory::Copy(gpu_place, static_cast<void *>(p_target_data),
platform::CPUPlace(), p_src_data, ele_size, dev_ctx->stream());
} else {
#ifdef PADDLE_WITH_CUDA
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
int device_num = platform::GetCurrentDeviceId();
platform::CUDAPlace gpu_place(device_num);
auto *t_data = tensor->mutable_data<T>(gpu_place);
auto *dev_ctx =
static_cast<const platform::CUDADeviceContext *>(pool.Get(gpu_place));

memory::Copy(gpu_place, static_cast<void *>(t_data), platform::CPUPlace(),
data, ele_size, dev_ctx->stream());
memory::Copy(gpu_place, static_cast<void *>(p_target_data), gpu_place,
p_src_data, ele_size, dev_ctx->stream());
}
cudaStreamSynchronize(dev_ctx->stream());
return target;
#else
PADDLE_THROW(platform::errors::Unavailable(
"Not compiled with CUDA, should not reach here."));
PADDLE_THROW(platform::errors::Unavailable(
"Not compiled with CUDA, should not reach here."));
Owner: polish error message. "Paddle is not compiled with CUDA." is ok; "should not" is not needed.
Collaborator (author): done
#endif
}
return Tensor(PlaceType::kGPU);
}

template <typename T>
void Tensor::copy_to_cpu(T *data) {
Tensor Tensor::copy_to_cpu() {
GET_CASTED_TENSOR;
auto ele_num = tensor->numel();
auto *t_data = tensor->data<T>();
auto t_place = tensor->place();

Tensor target = Tensor(PlaceType::kCPU);
target.reshape(shape());
auto *p_target_data = target.template mutable_data<T>();
if (platform::is_cpu_place(t_place)) {
std::memcpy(static_cast<void *>(data), t_data, ele_num * sizeof(T));
std::memcpy(static_cast<void *>(p_target_data), t_data,
ele_num * sizeof(T));
} else {
#ifdef PADDLE_WITH_CUDA
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto gpu_place = BOOST_GET_CONST(platform::CUDAPlace, t_place);
auto *dev_ctx =
static_cast<const platform::CUDADeviceContext *>(pool.Get(gpu_place));
memory::Copy(platform::CPUPlace(), static_cast<void *>(data), gpu_place,
t_data, ele_num * sizeof(T), dev_ctx->stream());
memory::Copy(platform::CPUPlace(), static_cast<void *>(p_target_data),
gpu_place, t_data, ele_num * sizeof(T), dev_ctx->stream());

cudaStreamSynchronize(dev_ctx->stream());
#else
PADDLE_THROW(platform::errors::Unavailable(
Owner: same as above.
Collaborator (author): done
"Not compile with CUDA, should not reach here."));
#endif
}
return target;
}

template void Tensor::copy_from_cpu<float>(const float *data);
template void Tensor::copy_from_cpu<double>(const double *data);
template void Tensor::copy_from_cpu<int64_t>(const int64_t *data);
template void Tensor::copy_from_cpu<int32_t>(const int32_t *data);
template void Tensor::copy_from_cpu<uint8_t>(const uint8_t *data);
template void Tensor::copy_from_cpu<int8_t>(const int8_t *data);

template void Tensor::copy_to_cpu<float>(float *data);
template void Tensor::copy_to_cpu<double>(double *data);
template void Tensor::copy_to_cpu<int64_t>(int64_t *data);
template void Tensor::copy_to_cpu<int32_t>(int32_t *data);
template void Tensor::copy_to_cpu<uint8_t>(uint8_t *data);
template void Tensor::copy_to_cpu<int8_t>(int8_t *data);
template Tensor Tensor::copy_to_gpu<paddle::platform::float16>();
template Tensor Tensor::copy_to_gpu<paddle::platform::bfloat16>();
template Tensor Tensor::copy_to_gpu<paddle::platform::complex64>();
template Tensor Tensor::copy_to_gpu<paddle::platform::complex128>();
template Tensor Tensor::copy_to_gpu<float>();
template Tensor Tensor::copy_to_gpu<double>();
template Tensor Tensor::copy_to_gpu<int64_t>();
template Tensor Tensor::copy_to_gpu<int32_t>();
template Tensor Tensor::copy_to_gpu<uint8_t>();
template Tensor Tensor::copy_to_gpu<int8_t>();
template Tensor Tensor::copy_to_gpu<int16_t>();

template Tensor Tensor::copy_to_cpu<paddle::platform::float16>();
template Tensor Tensor::copy_to_cpu<paddle::platform::bfloat16>();
template Tensor Tensor::copy_to_cpu<paddle::platform::complex64>();
template Tensor Tensor::copy_to_cpu<paddle::platform::complex128>();
template Tensor Tensor::copy_to_cpu<float>();
template Tensor Tensor::copy_to_cpu<double>();
template Tensor Tensor::copy_to_cpu<int64_t>();
template Tensor Tensor::copy_to_cpu<int32_t>();
template Tensor Tensor::copy_to_cpu<uint8_t>();
template Tensor Tensor::copy_to_cpu<int8_t>();
template Tensor Tensor::copy_to_cpu<int16_t>();

template float *Tensor::data<float>() const;
template double *Tensor::data<double>() const;
template int64_t *Tensor::data<int64_t>() const;
template int32_t *Tensor::data<int32_t>() const;
template uint8_t *Tensor::data<uint8_t>() const;
template int8_t *Tensor::data<int8_t>() const;
template paddle::platform::float16 *Tensor::data<paddle::platform::float16>()
const;
template paddle::platform::bfloat16 *Tensor::data<paddle::platform::bfloat16>()
const;
template paddle::platform::complex128 *
Tensor::data<paddle::platform::complex128>() const;
template paddle::platform::complex64 *
Tensor::data<paddle::platform::complex64>() const;
template int16_t *Tensor::data<int16_t>() const;

template float *Tensor::mutable_data<float>();
template double *Tensor::mutable_data<double>();
template int64_t *Tensor::mutable_data<int64_t>();
template int32_t *Tensor::mutable_data<int32_t>();
template uint8_t *Tensor::mutable_data<uint8_t>();
template int8_t *Tensor::mutable_data<int8_t>();
template paddle::platform::float16 *
Tensor::mutable_data<paddle::platform::float16>();
template paddle::platform::bfloat16 *
Tensor::mutable_data<paddle::platform::bfloat16>();
template paddle::platform::complex128 *
Tensor::mutable_data<paddle::platform::complex128>();
template paddle::platform::complex64 *
Tensor::mutable_data<paddle::platform::complex64>();
template int16_t *Tensor::mutable_data<int16_t>();

template float *Tensor::mutable_data<float>(const PlaceType &place);
template double *Tensor::mutable_data<double>(const PlaceType &place);
template int64_t *Tensor::mutable_data<int64_t>(const PlaceType &place);
template int32_t *Tensor::mutable_data<int32_t>(const PlaceType &place);
template uint8_t *Tensor::mutable_data<uint8_t>(const PlaceType &place);
template int8_t *Tensor::mutable_data<int8_t>(const PlaceType &place);
template paddle::platform::float16 *
Tensor::mutable_data<paddle::platform::float16>(const PlaceType &place);
template paddle::platform::bfloat16 *
Tensor::mutable_data<paddle::platform::bfloat16>(const PlaceType &place);
template paddle::platform::complex128 *
Tensor::mutable_data<paddle::platform::complex128>(const PlaceType &place);
template paddle::platform::complex64 *
Tensor::mutable_data<paddle::platform::complex64>(const PlaceType &place);
template int16_t *Tensor::mutable_data<int16_t>(const PlaceType &place);
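The long `template Tensor Tensor::copy_to_cpu<...>();` lists above are explicit template instantiations: the template bodies live in tensor.cc rather than the header, so every supported element type must be instantiated there, or callers in other translation units would hit link errors. The pattern in miniature (`twice` is a hypothetical stand-in, not a Paddle function):

```cpp
#include <cassert>

// Explicit-instantiation pattern: in Paddle the template definitions sit in
// tensor.cc and each supported T (float, double, int64_t, float16, ...) is
// listed explicitly, one line per type.
template <typename T>
T twice(T v) { return v + v; }

// In a real split build these lines live in the .cc file; the header would
// only declare `template <typename T> T twice(T v);`.
template float twice<float>(float);
template int twice<int>(int);
```

The cost of this approach is exactly what this PR shows: every new dtype (bfloat16, complex64, complex128, int16_t) requires touching each instantiation list.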

std::vector<int> Tensor::shape() const {
GET_CASTED_TENSOR
return framework::vectorize<int>(tensor->dims());
}

void Tensor::SetLoD(const std::vector<std::vector<size_t>> &x) {
GET_CASTED_TENSOR;
framework::LoD lod;
for (auto &level : x) {
lod.emplace_back(level);
}
tensor->set_lod(lod);
}

std::vector<std::vector<size_t>> Tensor::lod() const {
GET_CASTED_TENSOR;
std::vector<std::vector<size_t>> res;
for (auto &level : tensor->lod()) {
res.emplace_back(level);
}
return res;
}

const PlaceType &Tensor::place() const {
GET_CASTED_TENSOR;
if (platform::is_cpu_place(tensor->place())) {
1 change: 1 addition & 0 deletions paddle/fluid/framework/CMakeLists.txt
@@ -323,6 +323,7 @@ configure_file(commit.h.in commit.h)
cc_library(custom_tensor SRCS ../extension/src/tensor.cc DEPS lod_tensor)
cc_library(op_meta_info SRCS ../extension/src/op_meta_info.cc DEPS custom_tensor)
cc_library(custom_operator SRCS custom_operator.cc DEPS operator op_registry device_context dynamic_loader custom_tensor op_meta_info)
cc_test(custom_tensor_test SRCS custom_tensor_test.cc DEPS custom_tensor)

set(FLUID_FRAMEWORK_MODULES proto_desc memory lod_tensor executor data_feed_proto layer dynamic_loader)
