diff --git a/paddle/fluid/extension/include/dtype.h b/paddle/fluid/extension/include/dtype.h index e01e94a6a726d..3db1f5c308471 100644 --- a/paddle/fluid/extension/include/dtype.h +++ b/paddle/fluid/extension/include/dtype.h @@ -32,6 +32,7 @@ enum DataType { INT16, UINT8, INT8, + BOOL, // TODO(JiabinYang) support more data types if needed. }; diff --git a/paddle/fluid/extension/include/tensor.h b/paddle/fluid/extension/include/tensor.h index 02f32a6c146f5..1140efe5c1906 100644 --- a/paddle/fluid/extension/include/tensor.h +++ b/paddle/fluid/extension/include/tensor.h @@ -59,17 +59,11 @@ class Tensor { /// \brief Copy the host memory to tensor data. /// It's usually used to set the input tensor data. - /// \param data The pointer of the data, from which + /// \param PlaceType of target place, from which /// the tensor will copy. - template - Tensor copy_to_gpu(); - /// \brief Copy the tensor data to the host memory. - /// It's usually used to get the output tensor data. - /// \param[out] data The tensor will copy the data to - /// the address. template - Tensor copy_to_cpu(); + Tensor copy_to(const PlaceType& place); /// \brief Return the shape of the Tensor. std::vector shape() const; @@ -89,6 +83,9 @@ class Tensor { /// \return Place. const PlaceType& place() const; + /// \brief Cast datatype from one to another + Tensor cast(const DataType& target_type); + private: friend class framework::CustomTensorUtils; mutable std::shared_ptr tensor_; diff --git a/paddle/fluid/extension/src/tensor.cc b/paddle/fluid/extension/src/tensor.cc index 5aef6891a64ca..04c5faf3dc5ef 100644 --- a/paddle/fluid/extension/src/tensor.cc +++ b/paddle/fluid/extension/src/tensor.cc @@ -18,10 +18,76 @@ limitations under the License. */ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/platform/enforce.h" -#include "paddle/fluid/platform/gpu_info.h" +#include "paddle/fluid/platform/transform.h" namespace paddle { +template +struct CastDataTypeFunctor { + HOSTDEVICE inline OutType operator()(InType in) const { + return static_cast(in); + } +}; + +template +struct CastDataType { + CastDataType(const framework::Tensor &in, framework::Tensor *out, + const platform::DeviceContext *ctx) + : in_(in), out_(out), ctx_(ctx) {} + const framework::Tensor in_; + framework::Tensor *out_; + const platform::DeviceContext *ctx_; + + template + void apply() { + auto *in_begin = in_.data(); + auto *in_end = in_begin + in_.numel(); + auto *out_begin = out_->mutable_data(in_.place()); + + if (platform::is_cpu_place(in_.place())) { + platform::Transform trans; + auto *context = static_cast(ctx_); + trans(*context, in_begin, in_end, out_begin, + CastDataTypeFunctor()); +#ifdef __NVCC__ + } else if (platform::is_gpu_place(in_.place())) { + platform::Transform trans; + auto *context = static_cast(ctx_); + trans(*context, in_begin, in_end, out_begin, + CastDataTypeFunctor()); + context->Wait(); +#endif + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Place type is not supported when casting data type.")); + } + } +}; +template +void GpuCopy(T *src, T *dst, PlaceType src_plc, PlaceType dst_plc, + int64_t ele_size) { +#ifdef PADDLE_WITH_CUDA + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + int device_num = paddle::platform::GetCurrentDeviceId(); + platform::CUDAPlace gpu_place(device_num); + auto *dev_ctx = + static_cast(pool.Get(gpu_place)); + if ((src_plc == PlaceType::kGPU) && (dst_plc == PlaceType::kCPU)) { + 
memory::Copy(platform::CPUPlace(), static_cast(dst), gpu_place, src, + ele_size, dev_ctx->stream()); + } else if ((src_plc == PlaceType::kGPU) && (dst_plc == PlaceType::kGPU)) { + memory::Copy(gpu_place, static_cast(dst), gpu_place, src, ele_size, + dev_ctx->stream()); + } else if ((src_plc == PlaceType::kCPU) && (dst_plc == PlaceType::kGPU)) { + memory::Copy(gpu_place, static_cast(dst), platform::CPUPlace(), src, + ele_size, dev_ctx->stream()); + } else { + PADDLE_THROW("Only GPU related Copy can reach this func."); + } + cudaStreamSynchronize(dev_ctx->stream()); +#endif +} + #define GET_CASTED_TENSOR \ if (!tensor_) { \ tensor_ = std::make_shared(); \ @@ -55,12 +121,12 @@ T *Tensor::mutable_data() { case static_cast(PlaceType::kCPU): { return tensor->mutable_data(platform::CPUPlace()); } - case static_cast(PlaceType::kGPU): { #ifdef PADDLE_WITH_CUDA + case static_cast(PlaceType::kGPU): { int device_num = platform::GetCurrentDeviceId(); return tensor->mutable_data(platform::CUDAPlace(device_num)); -#endif } +#endif default: PADDLE_THROW(platform::errors::Unavailable( "CustomOp unsupported place: %d", static_cast(place_))); @@ -99,13 +165,14 @@ DataType Tensor::type() const { return DataType::COMPLEX64; } else if (type == framework::proto::VarType::COMPLEX128) { return DataType::COMPLEX128; + } else if (type == framework::proto::VarType::BOOL) { + return DataType::BOOL; } return DataType::FLOAT32; } template -Tensor Tensor::copy_to_gpu() { -#ifdef PADDLE_WITH_CUDA +Tensor Tensor::copy_to(const PlaceType &target_place) { GET_CASTED_TENSOR; PADDLE_ENFORCE_GE(tensor->numel(), 0, platform::errors::PreconditionNotMet( @@ -113,85 +180,47 @@ Tensor Tensor::copy_to_gpu() { "std::vector &shape)" "function before copying data from cpu.")); size_t ele_size = tensor->numel() * sizeof(T); - Tensor target = Tensor(PlaceType::kGPU); + auto *p_src_data = tensor->data(); + auto src_place = place(); + Tensor target = Tensor(target_place); target.reshape(shape()); auto *p_target_data = target.template mutable_data(); - auto p_src_data = tensor->data(); - - platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); - int device_num = platform::GetCurrentDeviceId(); - platform::CUDAPlace gpu_place(device_num); - auto *dev_ctx = - static_cast(pool.Get(gpu_place)); - if (platform::is_cpu_place(tensor->place())) { - memory::Copy(gpu_place, static_cast(p_target_data), - platform::CPUPlace(), p_src_data, ele_size, dev_ctx->stream()); - } else { - memory::Copy(gpu_place, static_cast(p_target_data), gpu_place, - p_src_data, ele_size, dev_ctx->stream()); - } - cudaStreamSynchronize(dev_ctx->stream()); - return target; -#else - PADDLE_THROW( - platform::errors::Unavailable("PaddlePaddle is not compiled with CUDA")); -#endif - return Tensor(PlaceType::kGPU); -} -template -Tensor Tensor::copy_to_cpu() { - GET_CASTED_TENSOR; - auto ele_num = tensor->numel(); - auto *t_data = tensor->data(); - auto t_place = tensor->place(); - Tensor target = Tensor(PlaceType::kCPU); - target.reshape(shape()); - auto *p_target_data = target.template mutable_data(); - if (platform::is_cpu_place(t_place)) { - std::memcpy(static_cast(p_target_data), t_data, - ele_num * sizeof(T)); + if ((src_place == PlaceType::kCPU) && (target_place == PlaceType::kCPU)) { + std::memcpy(static_cast(p_target_data), p_src_data, ele_size); + } else if ((src_place == PlaceType::kGPU) && + (target_place == PlaceType::kCPU)) { + GpuCopy(p_src_data, p_target_data, src_place, target_place, ele_size); + } else if ((src_place == PlaceType::kCPU) && + 
(target_place == PlaceType::kGPU)) { + GpuCopy(p_src_data, p_target_data, src_place, target_place, ele_size); + } else if ((src_place == PlaceType::kGPU) && + (target_place == PlaceType::kGPU)) { + GpuCopy(p_src_data, p_target_data, src_place, target_place, ele_size); } else { -#ifdef PADDLE_WITH_CUDA - platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); - auto gpu_place = BOOST_GET_CONST(platform::CUDAPlace, t_place); - auto *dev_ctx = - static_cast(pool.Get(gpu_place)); - memory::Copy(platform::CPUPlace(), static_cast(p_target_data), - gpu_place, t_data, ele_num * sizeof(T), dev_ctx->stream()); - - cudaStreamSynchronize(dev_ctx->stream()); -#else PADDLE_THROW(platform::errors::Unavailable( - "PaddlePaddle is not compiled with CUDA.")); -#endif + "Not supported place transform of place: %d to place: %d", + static_cast(src_place), static_cast(target_place))); } return target; } -template Tensor Tensor::copy_to_gpu(); -template Tensor Tensor::copy_to_gpu(); -template Tensor Tensor::copy_to_gpu(); -template Tensor Tensor::copy_to_gpu(); -template Tensor Tensor::copy_to_gpu(); -template Tensor Tensor::copy_to_gpu(); -template Tensor Tensor::copy_to_gpu(); -template Tensor Tensor::copy_to_gpu(); -template Tensor Tensor::copy_to_gpu(); -template Tensor Tensor::copy_to_gpu(); -template Tensor Tensor::copy_to_gpu(); - -template Tensor Tensor::copy_to_cpu(); -template Tensor Tensor::copy_to_cpu(); -template Tensor Tensor::copy_to_cpu(); -template Tensor Tensor::copy_to_cpu(); -template Tensor Tensor::copy_to_cpu(); -template Tensor Tensor::copy_to_cpu(); -template Tensor Tensor::copy_to_cpu(); -template Tensor Tensor::copy_to_cpu(); -template Tensor Tensor::copy_to_cpu(); -template Tensor Tensor::copy_to_cpu(); -template Tensor Tensor::copy_to_cpu(); +template Tensor Tensor::copy_to( + const PlaceType &target_place); +template Tensor Tensor::copy_to( + const PlaceType &target_place); +template Tensor Tensor::copy_to( + const PlaceType &target_place); +template Tensor Tensor::copy_to( + const PlaceType &target_place); +template Tensor Tensor::copy_to(const PlaceType &target_place); +template Tensor Tensor::copy_to(const PlaceType &target_place); +template Tensor Tensor::copy_to(const PlaceType &target_place); +template Tensor Tensor::copy_to(const PlaceType &target_place); +template Tensor Tensor::copy_to(const PlaceType &target_place); +template Tensor Tensor::copy_to(const PlaceType &target_place); +template Tensor Tensor::copy_to(const PlaceType &target_place); +template Tensor Tensor::copy_to(const PlaceType &target_place); template float *Tensor::data() const; template double *Tensor::data() const; @@ -208,6 +237,7 @@ Tensor::data() const; template paddle::platform::complex64 * Tensor::data() const; template int16_t *Tensor::data() const; +template bool *Tensor::data() const; template float *Tensor::mutable_data(); template double *Tensor::mutable_data(); @@ -224,6 +254,7 @@ Tensor::mutable_data(); template paddle::platform::complex64 * Tensor::mutable_data(); template int16_t *Tensor::mutable_data(); +template bool *Tensor::mutable_data(); template float *Tensor::mutable_data(const PlaceType &place); template double *Tensor::mutable_data(const PlaceType &place); @@ -240,6 +271,7 @@ Tensor::mutable_data(const PlaceType &place); template paddle::platform::complex64 * Tensor::mutable_data(const PlaceType &place); template int16_t *Tensor::mutable_data(const PlaceType &place); +template bool *Tensor::mutable_data(const PlaceType &place); std::vector Tensor::shape() const 
{ GET_CASTED_TENSOR @@ -261,6 +293,62 @@ const PlaceType &Tensor::place() const { return place_; } +Tensor Tensor::cast(const DataType &target_type) { + GET_CASTED_TENSOR; + Tensor rlt = Tensor(place()); + rlt.reshape(this->shape()); + auto rlt_tensor_ = static_cast(rlt.tensor_.get()); + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto ctx = pool.Get(tensor->place()); + auto src_type = tensor->type(); + auto dst_type = + framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(target_type); + switch (src_type) { + case framework::proto::VarType::FP16: + framework::VisitDataType( + dst_type, CastDataType(*tensor, rlt_tensor_, ctx)); + break; + case framework::proto::VarType::BF16: + framework::VisitDataType(dst_type, CastDataType( + *tensor, rlt_tensor_, ctx)); + break; + case framework::proto::VarType::FP32: + framework::VisitDataType(dst_type, + CastDataType(*tensor, rlt_tensor_, ctx)); + break; + case framework::proto::VarType::FP64: + framework::VisitDataType(dst_type, + CastDataType(*tensor, rlt_tensor_, ctx)); + break; + case framework::proto::VarType::INT32: + framework::VisitDataType(dst_type, + CastDataType(*tensor, rlt_tensor_, ctx)); + break; + case framework::proto::VarType::INT64: + framework::VisitDataType( + dst_type, CastDataType(*tensor, rlt_tensor_, ctx)); + break; + case framework::proto::VarType::BOOL: + framework::VisitDataType(dst_type, + CastDataType(*tensor, rlt_tensor_, ctx)); + break; + case framework::proto::VarType::INT16: + framework::VisitDataType( + dst_type, CastDataType(*tensor, rlt_tensor_, ctx)); + break; + case framework::proto::VarType::UINT8: + framework::VisitDataType( + dst_type, CastDataType(*tensor, rlt_tensor_, ctx)); + break; + // TODO(JiabinYang): Support Complex later + default: + PADDLE_THROW(platform::errors::Unimplemented( + "Data type (%s) is not supported when casting data type.", + framework::DataTypeToString(src_type))); + } + return rlt; +} + int64_t Tensor::size() const { GET_CASTED_TENSOR; return tensor->numel(); @@ -273,12 +361,13 @@ void CustomTensorUtils::ShareDataTo(const paddle::Tensor &src, void *dst) { *static_cast(src.tensor_.get())); } -void CustomTensorUtils::ShareDataFrom(void *src, const paddle::Tensor &dst) { +void CustomTensorUtils::ShareDataFrom(const void *src, + const paddle::Tensor &dst) { if (!dst.tensor_) { dst.tensor_ = std::make_shared(); } auto *tensor = static_cast(dst.tensor_.get()); - tensor->ShareDataWith(*static_cast(src)); + tensor->ShareDataWith(*static_cast(src)); } } // namespace framework diff --git a/paddle/fluid/framework/custom_operator.cc b/paddle/fluid/framework/custom_operator.cc index 66f4349f1a746..1e2a77e915dea 100644 --- a/paddle/fluid/framework/custom_operator.cc +++ b/paddle/fluid/framework/custom_operator.cc @@ -35,7 +35,6 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/platform/dynload/dynamic_loader.h" -#include "paddle/fluid/platform/gpu_info.h" #include "paddle/fluid/string/string_helper.h" namespace paddle { @@ -76,102 +75,6 @@ inline bool IsMemberOf(const std::vector& vec, } // namespace detail -// PaddlePlace <-> platform::Place -platform::Place ConvertEnumPlaceToInnerPlace(const PlaceType& pc) { - if (pc == PlaceType::kCPU) { - return platform::Place(platform::CPUPlace()); - } else if (pc == PlaceType::kGPU) { -#ifdef PADDLE_WITH_CUDA - return platform::Place(platform::CUDAPlace(platform::GetCurrentDeviceId())); -#endif - } else { - PADDLE_THROW( - platform::errors::Unimplemented("Unsupported place type code(%d) when " - "casting enum place to paddle place.", - static_cast(pc))); - } - return platform::Place(); -} - -PlaceType ConvertInnerPlaceToEnumPlace(const platform::Place& pc) { - if (platform::is_cpu_place(pc)) { - return PlaceType::kCPU; - } else if (platform::is_gpu_place(pc)) { -#ifdef PADDLE_WITH_CUDA - return PlaceType::kGPU; -#endif - } else { - PADDLE_THROW(platform::errors::Unimplemented( - "Unsupported place type `%s` when casting paddle place to enum place.", - pc)); - } - return PlaceType::kUNK; -} - -proto::VarType::Type ConvertEnumDTypeToInnerDType( - const paddle::DataType& dtype) { - switch (dtype) { - case paddle::DataType::COMPLEX128: - return proto::VarType::COMPLEX128; - case paddle::DataType::COMPLEX64: - return proto::VarType::COMPLEX64; - case paddle::DataType::FLOAT64: - return proto::VarType::FP64; - case paddle::DataType::FLOAT32: - return proto::VarType::FP32; - case paddle::DataType::FLOAT16: - return proto::VarType::FP16; - case paddle::DataType::BFLOAT16: - return proto::VarType::BF16; - case paddle::DataType::UINT8: - return proto::VarType::UINT8; - case paddle::DataType::INT8: - return proto::VarType::INT8; - case paddle::DataType::INT32: - return proto::VarType::INT32; - case paddle::DataType::INT64: - return proto::VarType::INT64; - default: - PADDLE_THROW(platform::errors::Unimplemented( - "Unsupported data type code(%d) when casting enum data type into " - "paddle data type.", - dtype)); - } -} - -paddle::DataType ConvertInnerDTypeToEnumDType( - const proto::VarType::Type& dtype) { - switch (dtype) { - case proto::VarType::COMPLEX128: - return paddle::DataType::COMPLEX128; - case proto::VarType::COMPLEX64: - return paddle::DataType::COMPLEX64; - case proto::VarType::FP64: - return paddle::DataType::FLOAT64; - case proto::VarType::FP32: - return paddle::DataType::FLOAT32; - case proto::VarType::FP16: - return paddle::DataType::FLOAT16; - case proto::VarType::BF16: - return paddle::DataType::BFLOAT16; - case proto::VarType::INT64: - return paddle::DataType::INT64; - case proto::VarType::INT32: - return paddle::DataType::INT32; - case proto::VarType::INT8: - return paddle::DataType::INT8; - case proto::VarType::UINT8: - return paddle::DataType::UINT8; - case proto::VarType::INT16: - return paddle::DataType::INT16; - default: - PADDLE_THROW(platform::errors::Unimplemented( - "Unsupported data type `%s` when casting paddle data type into enum " - "data type.", - DataTypeToString(dtype))); - } -} - ////////////////// Kernel Define //////////////////// // custom op kernel call function define @@ -189,8 +92,9 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx, PADDLE_ENFORCE_EQ(x->IsInitialized(), true, platform::errors::InvalidArgument( "Input tensor (%s) is not initialized.")); - auto 
custom_in = paddle::Tensor(ConvertInnerPlaceToEnumPlace(x->place())); - CustomTensorUtils::ShareDataFrom((void*)x, custom_in); // NOLINT + auto custom_in = paddle::Tensor( + CustomTensorUtils::ConvertInnerPlaceToEnumPlace(x->place())); + CustomTensorUtils::ShareDataFrom(static_cast(x), custom_in); custom_ins.emplace_back(custom_in); } @@ -400,7 +304,8 @@ void RegisterOperatorKernelWithPlace(const std::string& name, const PlaceType& place, const std::vector& inputs, const std::vector& outputs) { - OpKernelType key(type, ConvertEnumPlaceToInnerPlace(place)); + OpKernelType key(type, + CustomTensorUtils::ConvertEnumPlaceToInnerPlace(place)); VLOG(1) << "Custom Operator: op kernel key: " << key; OperatorWithKernel::AllOpKernels()[name][key] = [kernel_func, inputs, outputs](const framework::ExecutionContext& ctx) { @@ -505,7 +410,8 @@ void RegisterOperatorWithMetaInfo( VLOG(1) << "Custom Operator: InferDtype - get input dtype."; for (auto& in_name : op_inputs) { auto dtype = ctx->GetInputDataType(in_name); - input_dtypes.emplace_back(ConvertInnerDTypeToEnumDType(dtype)); + input_dtypes.emplace_back( + CustomTensorUtils::ConvertInnerDTypeToEnumDType(dtype)); } VLOG(1) << "Custom Operator: InferDtype - infer output dtype."; @@ -513,8 +419,9 @@ void RegisterOperatorWithMetaInfo( VLOG(1) << "Custom Operator: InferDtype - set output dtype."; for (size_t i = 0; i < op_outputs.size(); ++i) { - ctx->SetOutputDataType(op_outputs[i], - ConvertEnumDTypeToInnerDType(output_dtypes[i])); + ctx->SetOutputDataType( + op_outputs[i], + CustomTensorUtils::ConvertEnumDTypeToInnerDType(output_dtypes[i])); } }; diff --git a/paddle/fluid/framework/custom_tensor_test.cc b/paddle/fluid/framework/custom_tensor_test.cc index a9e7e02d093d1..1a8ebca4d267f 100644 --- a/paddle/fluid/framework/custom_tensor_test.cc +++ b/paddle/fluid/framework/custom_tensor_test.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+#include "glog/logging.h" #include "gtest/gtest.h" #include "paddle/extension.h" #include "paddle/fluid/framework/lod_tensor.h" @@ -30,17 +31,19 @@ paddle::Tensor InitCPUTensorForTest() { template void TestCopyTensor() { auto t1 = InitCPUTensorForTest(); - auto t1_cpu_cp = t1.template copy_to_cpu(); + auto t1_cpu_cp = t1.template copy_to(paddle::PlaceType::kCPU); CHECK((paddle::PlaceType::kCPU == t1_cpu_cp.place())); for (int64_t i = 0; i < t1.size(); i++) { CHECK_EQ(t1_cpu_cp.template data()[i], 5); } #ifdef PADDLE_WITH_CUDA - auto t1_gpu_cp = t1_cpu_cp.template copy_to_gpu(); + VLOG(2) << "Do GPU copy test"; + auto t1_gpu_cp = t1_cpu_cp.template copy_to(paddle::PlaceType::kGPU); CHECK((paddle::PlaceType::kGPU == t1_gpu_cp.place())); - auto t1_gpu_cp_cp = t1_gpu_cp.template copy_to_gpu(); + auto t1_gpu_cp_cp = t1_gpu_cp.template copy_to(paddle::PlaceType::kGPU); CHECK((paddle::PlaceType::kGPU == t1_gpu_cp_cp.place())); - auto t1_gpu_cp_cp_cpu = t1_gpu_cp.template copy_to_cpu(); + auto t1_gpu_cp_cp_cpu = + t1_gpu_cp.template copy_to(paddle::PlaceType::kCPU); CHECK((paddle::PlaceType::kCPU == t1_gpu_cp_cp_cpu.place())); for (int64_t i = 0; i < t1.size(); i++) { CHECK_EQ(t1_gpu_cp_cp_cpu.template data()[i], 5); @@ -50,13 +53,17 @@ void TestCopyTensor() { void TestAPIPlace() { std::vector tensor_shape = {5, 5}; +#ifdef PADDLE_WITH_CUDA auto t1 = paddle::Tensor(paddle::PlaceType::kGPU); t1.reshape(tensor_shape); t1.mutable_data(); +#endif auto t2 = paddle::Tensor(paddle::PlaceType::kCPU); t2.reshape(tensor_shape); t2.mutable_data(); +#ifdef PADDLE_WITH_CUDA CHECK((paddle::PlaceType::kGPU == t1.place())); +#endif CHECK((paddle::PlaceType::kCPU == t2.place())); } @@ -77,31 +84,63 @@ paddle::DataType TestDtype() { return t1.type(); } +template +void TestCast(paddle::DataType data_type) { + std::vector tensor_shape = {5, 5}; + auto t1 = paddle::Tensor(paddle::PlaceType::kCPU); + t1.reshape(tensor_shape); + t1.template mutable_data(); + auto t2 = t1.cast(data_type); + CHECK_EQ(t2.type(), data_type); +} + void GroupTestCopy() { - VLOG(0) << "Float cpu-cpu-gpu-gpu-cpu"; + VLOG(2) << "Float cpu-cpu-gpu-gpu-cpu"; TestCopyTensor(); - VLOG(0) << "Double cpu-cpu-gpu-gpu-cpu"; + VLOG(2) << "Double cpu-cpu-gpu-gpu-cpu"; TestCopyTensor(); // TODO(JiabinYang): Support these test later - // VLOG(0) << "Fp16 cpu-cpu-gpu-gpu-cpu"; + // VLOG(2) << "Fp16 cpu-cpu-gpu-gpu-cpu"; // TestCopyTensor(); - // VLOG(0) << "BF16 cpu-cpu-gpu-gpu-cpu"; + // VLOG(2) << "BF16 cpu-cpu-gpu-gpu-cpu"; // TestCopyTensor(); - // VLOG(0) << "complex128 cpu-cpu-gpu-gpu-cpu"; + // VLOG(2) << "complex128 cpu-cpu-gpu-gpu-cpu"; // TestCopyTensor(); - // VLOG(0) << "complex64 cpu-cpu-gpu-gpu-cpu"; + // VLOG(2) << "complex64 cpu-cpu-gpu-gpu-cpu"; // TestCopyTensor(); - // VLOG(0) << "int cpu-cpu-gpu-gpu-cpu"; + // VLOG(2) << "int cpu-cpu-gpu-gpu-cpu"; TestCopyTensor(); - VLOG(0) << "int64 cpu-cpu-gpu-gpu-cpu"; + VLOG(2) << "int64 cpu-cpu-gpu-gpu-cpu"; TestCopyTensor(); - VLOG(0) << "int16 cpu-cpu-gpu-gpu-cpu"; + VLOG(2) << "int16 cpu-cpu-gpu-gpu-cpu"; TestCopyTensor(); - VLOG(0) << "int8 cpu-cpu-gpu-gpu-cpu"; + VLOG(2) << "int8 cpu-cpu-gpu-gpu-cpu"; TestCopyTensor(); - VLOG(0) << "uint8 cpu-cpu-gpu-gpu-cpu"; + VLOG(2) << "uint8 cpu-cpu-gpu-gpu-cpu"; TestCopyTensor(); } + +void GroupTestCast() { + VLOG(2) << "int cast"; + TestCast(paddle::DataType::FLOAT32); + VLOG(2) << "int32 cast"; + TestCast(paddle::DataType::FLOAT32); + VLOG(2) << "int64 cast"; + TestCast(paddle::DataType::FLOAT32); + VLOG(2) << "double cast"; + 
TestCast(paddle::DataType::FLOAT32); + VLOG(2) << "bfloat16 cast"; + TestCast(paddle::DataType::FLOAT32); + VLOG(2) << "float16 cast"; + TestCast(paddle::DataType::FLOAT32); + VLOG(2) << "bool cast"; + TestCast(paddle::DataType::FLOAT32); + VLOG(2) << "uint8 cast"; + TestCast(paddle::DataType::FLOAT32); + VLOG(2) << "float cast"; + TestCast(paddle::DataType::FLOAT32); +} + void GroupTestDtype() { CHECK(TestDtype() == paddle::DataType::FLOAT32); CHECK(TestDtype() == paddle::DataType::FLOAT64); @@ -119,12 +158,14 @@ void GroupTestDtype() { } TEST(CustomTensor, copyTest) { - VLOG(0) << "TestCopy"; + VLOG(2) << "TestCopy"; GroupTestCopy(); - VLOG(0) << "TestDtype"; + VLOG(2) << "TestDtype"; GroupTestDtype(); - VLOG(0) << "TestShape"; + VLOG(2) << "TestShape"; TestAPISizeAndShape(); - VLOG(0) << "TestPlace"; + VLOG(2) << "TestPlace"; TestAPIPlace(); + VLOG(2) << "TestCast"; + GroupTestCast(); } diff --git a/paddle/fluid/framework/custom_tensor_utils.h b/paddle/fluid/framework/custom_tensor_utils.h index 3399091661747..81357529c14cd 100644 --- a/paddle/fluid/framework/custom_tensor_utils.h +++ b/paddle/fluid/framework/custom_tensor_utils.h @@ -17,6 +17,9 @@ limitations under the License. */ #include #include "paddle/fluid/extension/include/tensor.h" +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/platform/gpu_info.h" +#include "paddle/fluid/platform/place.h" namespace paddle { namespace framework { @@ -31,7 +34,109 @@ class CustomTensorUtils { /// \brief Share data FROM another tensor. /// Use this to pass tensor from op to op /// \return void. - static void ShareDataFrom(void* src, const paddle::Tensor& dst); + static void ShareDataFrom(const void* src, const Tensor& dst); + + static framework::proto::VarType::Type ConvertEnumDTypeToInnerDType( + const paddle::DataType& dtype) { + switch (dtype) { + case paddle::DataType::COMPLEX128: + return framework::proto::VarType::COMPLEX128; + case paddle::DataType::COMPLEX64: + return framework::proto::VarType::COMPLEX64; + case paddle::DataType::FLOAT64: + return framework::proto::VarType::FP64; + case paddle::DataType::FLOAT32: + return framework::proto::VarType::FP32; + case paddle::DataType::FLOAT16: + return framework::proto::VarType::FP16; + case paddle::DataType::BFLOAT16: + return framework::proto::VarType::BF16; + case paddle::DataType::UINT8: + return framework::proto::VarType::UINT8; + case paddle::DataType::INT8: + return framework::proto::VarType::INT8; + case paddle::DataType::INT32: + return framework::proto::VarType::INT32; + case paddle::DataType::INT64: + return framework::proto::VarType::INT64; + case paddle::DataType::BOOL: + return framework::proto::VarType::BOOL; + default: + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupported data type code(%d) when casting enum data type into " + "paddle data type.", + static_cast(dtype))); + } + } + + static paddle::DataType ConvertInnerDTypeToEnumDType( + const framework::proto::VarType::Type& dtype) { + switch (dtype) { + case framework::proto::VarType::COMPLEX128: + return paddle::DataType::COMPLEX128; + case framework::proto::VarType::COMPLEX64: + return paddle::DataType::COMPLEX64; + case framework::proto::VarType::FP64: + return paddle::DataType::FLOAT64; + case framework::proto::VarType::FP32: + return paddle::DataType::FLOAT32; + case framework::proto::VarType::FP16: + return paddle::DataType::FLOAT16; + case framework::proto::VarType::BF16: + return paddle::DataType::BFLOAT16; + case framework::proto::VarType::INT64: + return paddle::DataType::INT64; + 
case framework::proto::VarType::INT32: + return paddle::DataType::INT32; + case framework::proto::VarType::INT8: + return paddle::DataType::INT8; + case framework::proto::VarType::UINT8: + return paddle::DataType::UINT8; + case framework::proto::VarType::INT16: + return paddle::DataType::INT16; + case framework::proto::VarType::BOOL: + return paddle::DataType::BOOL; + default: + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupported data type `%s` when casting paddle data type into " + "enum data type.", + DataTypeToString(dtype))); + } + } + + // PaddlePlace <-> platform::Place + static platform::Place ConvertEnumPlaceToInnerPlace(const PlaceType& pc) { + if (pc == PlaceType::kCPU) { + return platform::Place(platform::CPUPlace()); + } else if (pc == PlaceType::kGPU) { +#ifdef PADDLE_WITH_CUDA + return platform::Place( + platform::CUDAPlace(platform::GetCurrentDeviceId())); +#endif + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupported place type code(%d) when " + "casting enum place to paddle place.", + static_cast(pc))); + } + return platform::Place(); + } + + static PlaceType ConvertInnerPlaceToEnumPlace(const platform::Place& pc) { + if (platform::is_cpu_place(pc)) { + return PlaceType::kCPU; + } else if (platform::is_gpu_place(pc)) { +#ifdef PADDLE_WITH_CUDA + return PlaceType::kGPU; +#endif + } else { + PADDLE_THROW( + platform::errors::Unimplemented("Unsupported place type `%s` when " + "casting paddle place to enum place.", + pc)); + } + return PlaceType::kUNK; + } }; } // namespace framework diff --git a/paddle/fluid/framework/data_type_transform.cc b/paddle/fluid/framework/data_type_transform.cc index 30a2ac2c6f6be..2479f932f32cd 100644 --- a/paddle/fluid/framework/data_type_transform.cc +++ b/paddle/fluid/framework/data_type_transform.cc @@ -97,10 +97,10 @@ void TransDataType(const OpKernelType& kernel_type_for_var, framework::VisitDataType(dst_type, CastDataType(in, out, ctx)); break; case proto::VarType::INT16: - framework::VisitDataType(dst_type, CastDataType(in, out, ctx)); + framework::VisitDataType(dst_type, CastDataType(in, out, ctx)); break; case proto::VarType::UINT8: - framework::VisitDataType(dst_type, CastDataType(in, out, ctx)); + framework::VisitDataType(dst_type, CastDataType(in, out, ctx)); break; default: PADDLE_THROW(platform::errors::Unimplemented(
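
A minimal usage sketch of the public API surface this patch introduces: the unified Tensor::copy_to(PlaceType), the new Tensor::cast(DataType), and BOOL element support. The helper name ToBoolMaskOnCPU and its signature are hypothetical illustrations; only copy_to, cast, PlaceType, and DataType::BOOL come from the diff above.

#include "paddle/extension.h"

// Hypothetical helper (illustration only): bring a tensor to CPU and turn it
// into a boolean mask using the APIs added in this patch.
paddle::Tensor ToBoolMaskOnCPU(paddle::Tensor x) {
  // copy_to() replaces the old copy_to_cpu()/copy_to_gpu() pair; the target
  // place is now an argument instead of being baked into the method name.
  // Assumes x holds float data, matching the template argument.
  auto x_cpu = x.copy_to<float>(paddle::PlaceType::kCPU);
  // cast() converts the element type; the result stays on the same place as
  // the source tensor, and BOOL is now an accepted DataType.
  return x_cpu.cast(paddle::DataType::BOOL);
}

Under the old API the same flow required choosing between copy_to_cpu() and copy_to_gpu() at the call site, and there was no public cast().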
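
The cast() implementation above relies on a double dispatch: an outer switch on the source dtype selects CastDataType<InType>, and framework::VisitDataType then selects the destination type. A self-contained sketch of that pattern, using stand-in names (DType, Caster, VisitDType, none of them Paddle symbols), may help when reviewing the switch in Tensor::cast:

#include <cstddef>
#include <cstdint>

// Stand-in runtime dtype enum (illustration only, not Paddle's DataType).
enum class DType { kFloat32, kInt32, kBool };

// Element-wise cast once both static types are known.
template <typename InType, typename OutType>
void CastRange(const InType* in, OutType* out, std::size_t n) {
  for (std::size_t i = 0; i < n; ++i) out[i] = static_cast<OutType>(in[i]);
}

// Plays the role of CastDataType<InType>: fixes the source type and lets a
// visitor supply the destination type.
template <typename InType>
struct Caster {
  const void* in;
  void* out;
  std::size_t n;
  template <typename OutType>
  void apply() const {
    CastRange(static_cast<const InType*>(in), static_cast<OutType*>(out), n);
  }
};

// Plays the role of framework::VisitDataType: maps the runtime destination
// dtype to a static OutType and invokes the visitor.
template <typename Visitor>
void VisitDType(DType dst, Visitor visitor) {
  switch (dst) {
    case DType::kFloat32: visitor.template apply<float>(); break;
    case DType::kInt32:   visitor.template apply<int32_t>(); break;
    case DType::kBool:    visitor.template apply<bool>(); break;
  }
}

// Plays the role of Tensor::cast(): the outer switch resolves the source type,
// the inner visit resolves the destination type.
void Cast(DType src, DType dst, const void* in, void* out, std::size_t n) {
  switch (src) {
    case DType::kFloat32: VisitDType(dst, Caster<float>{in, out, n}); break;
    case DType::kInt32:   VisitDType(dst, Caster<int32_t>{in, out, n}); break;
    case DType::kBool:    VisitDType(dst, Caster<bool>{in, out, n}); break;
  }
}

Seen this way, supporting a new element type such as BOOL only needs one extra case in each dispatch table plus the corresponding explicit instantiations, which is what this patch adds.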