diff --git a/paddle/fluid/extension/include/dtype.h b/paddle/fluid/extension/include/dtype.h
index 0e392605e1788..e01e94a6a726d 100644
--- a/paddle/fluid/extension/include/dtype.h
+++ b/paddle/fluid/extension/include/dtype.h
@@ -13,17 +13,26 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #pragma once
 
+#include "paddle/fluid/platform/bfloat16.h"
+#include "paddle/fluid/platform/complex128.h"
+#include "paddle/fluid/platform/complex64.h"
+#include "paddle/fluid/platform/float16.h"
 
 namespace paddle {
 
 enum DataType {
   FLOAT32,
   FLOAT64,
+  BFLOAT16,
+  COMPLEX128,
+  COMPLEX64,
+  FLOAT16,
   INT64,
   INT32,
+  INT16,
   UINT8,
   INT8,
-  // TODO(yangjiabin): Add other dtype support in next PR
+  // TODO(JiabinYang) support more data types if needed.
 };
 
 }  // namespace paddle
diff --git a/paddle/fluid/extension/include/tensor.h b/paddle/fluid/extension/include/tensor.h
index c072f168197d7..bf6cd63f24d15 100644
--- a/paddle/fluid/extension/include/tensor.h
+++ b/paddle/fluid/extension/include/tensor.h
@@ -16,7 +16,6 @@ limitations under the License. */
 
 #include
 #include
-
 #include "paddle/fluid/extension/include/dtype.h"
 #include "paddle/fluid/extension/include/place.h"
 
@@ -29,21 +28,24 @@ class Tensor {
   explicit Tensor(const PlaceType& place);
   /// \brief Reset the shape of the tensor.
   /// Generally it's only used for the input tensor.
-  /// Reshape must be called before calling mutable_data() or copy_from_cpu()
+  /// Reshape must be called before calling
+  /// mutable_data() or copy_from_cpu()
   /// \param shape The shape to set.
-  void Reshape(const std::vector<int>& shape);
+  void reshape(const std::vector<int>& shape);
 
-  /// \brief Get the memory pointer in CPU or GPU with specific data type.
+  /// \brief Get the memory pointer in CPU or GPU with
+  /// specific data type.
   /// Please Reshape the tensor first before call this.
   /// It's usually used to get input data pointer.
-  /// \param place The place of the tensor this will override the original place
-  /// of current tensor.
+  /// \param place The place of the tensor this will
+  /// override the original place of current tensor.
   template <typename T>
   T* mutable_data(const PlaceType& place);
 
-  /// \brief Get the memory pointer in CPU or GPU with specific data type.
-  /// Please Reshape the tensor first before call this.
-  /// It's usually used to get input data pointer.
+  /// \brief Get the memory pointer in CPU or GPU with
+  /// specific data type. Please Reshape the tensor
+  /// first before calling this. It's usually used to get
+  /// input data pointer.
   template <typename T>
   T* mutable_data();
 
@@ -55,27 +57,21 @@ class Tensor {
 
   /// \brief Copy the host memory to tensor data.
   /// It's usually used to set the input tensor data.
-  /// \param data The pointer of the data, from which the tensor will copy.
+  /// \param data The pointer of the data, from which
+  /// the tensor will copy.
   template <typename T>
-  void copy_from_cpu(const T* data);
+  Tensor copy_to_gpu();
 
   /// \brief Copy the tensor data to the host memory.
   /// It's usually used to get the output tensor data.
-  /// \param[out] data The tensor will copy the data to the address.
+  /// \param[out] data The tensor will copy the data to
+  /// the address.
   template <typename T>
-  void copy_to_cpu(T* data);
+  Tensor copy_to_cpu();
 
   /// \brief Return the shape of the Tensor.
   std::vector<int> shape() const;
 
-  /// \brief Set lod info of the tensor.
-  /// More about LOD can be seen here:
-  ///  https://www.paddlepaddle.org.cn/documentation/docs/zh/beginners_guide/basic_concept/lod_tensor.html#lodtensor
-  /// \param x the lod info.
-  void SetLoD(const std::vector<std::vector<size_t>>& x);
-  /// \brief Return the lod info of the tensor.
-  std::vector<std::vector<size_t>> lod() const;
-
   /// \brief Return the data type of the tensor.
   /// It's usually used to get the output tensor data type.
   /// \return The data type of the tensor.
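Note: as a quick orientation for the API change above, here is a minimal usage sketch (not part of the patch). It assumes a CUDA-enabled build, the `paddle/extension.h` umbrella header used by the new test further down, and the `std::vector<int>` shape type declared in tensor.h; the call order follows the doc comments: `reshape()` first, then `mutable_data()`, while the copy helpers now return a new `Tensor` instead of filling a caller-provided pointer.

```cpp
#include <vector>
#include "paddle/extension.h"

paddle::Tensor MakeFilledCpuTensor(float value) {
  // reshape() must be called before mutable_data(), as documented above.
  auto t = paddle::Tensor(paddle::PlaceType::kCPU);
  t.reshape(std::vector<int>{2, 3});
  auto* data = t.mutable_data<float>(paddle::PlaceType::kCPU);
  for (int64_t i = 0; i < t.size(); ++i) {
    data[i] = value;
  }
  return t;
}

void RoundTripExample() {
  auto cpu = MakeFilledCpuTensor(1.0f);
  // copy_to_gpu()/copy_to_cpu() replace copy_from_cpu(const T*)/copy_to_cpu(T*)
  // and return a fresh Tensor; copy_to_gpu() throws without a CUDA build.
  auto gpu = cpu.copy_to_gpu<float>();
  auto back = gpu.copy_to_cpu<float>();
  const float* host_ptr = back.data<float>();
  (void)host_ptr;  // read the six copied values through host_ptr here
}
```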
diff --git a/paddle/fluid/extension/src/tensor.cc b/paddle/fluid/extension/src/tensor.cc
index db65923f1f5fb..35fec36b04cea 100644
--- a/paddle/fluid/extension/src/tensor.cc
+++ b/paddle/fluid/extension/src/tensor.cc
@@ -13,9 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/extension/include/tensor.h"
-
 #include
-
 #include "paddle/fluid/framework/custom_tensor_utils.h"
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/memory/memcpy.h"
@@ -30,7 +28,7 @@ namespace paddle {
   }                                                                   \
   auto *tensor = static_cast<framework::LoDTensor *>(tensor_.get());
 
-void Tensor::Reshape(const std::vector<int> &shape) {
+void Tensor::reshape(const std::vector<int> &shape) {
   GET_CASTED_TENSOR
   tensor->Resize(framework::make_ddim(shape));
 }
@@ -85,16 +83,29 @@ DataType Tensor::type() const {
     return DataType::INT64;
   } else if (type == framework::proto::VarType::INT32) {
     return DataType::INT32;
+  } else if (type == framework::proto::VarType::INT16) {
+    return DataType::INT16;
+  } else if (type == framework::proto::VarType::INT8) {
+    return DataType::INT8;
   } else if (type == framework::proto::VarType::UINT8) {
     return DataType::UINT8;
   } else if (type == framework::proto::VarType::FP64) {
     return DataType::FLOAT64;
+  } else if (type == framework::proto::VarType::BF16) {
+    return DataType::BFLOAT16;
+  } else if (type == framework::proto::VarType::FP16) {
+    return DataType::FLOAT16;
+  } else if (type == framework::proto::VarType::COMPLEX64) {
+    return DataType::COMPLEX64;
+  } else if (type == framework::proto::VarType::COMPLEX128) {
+    return DataType::COMPLEX128;
   }
   return DataType::FLOAT32;
 }
 
 template <typename T>
-void Tensor::copy_from_cpu(const T *data) {
+Tensor Tensor::copy_to_gpu() {
+#ifdef PADDLE_WITH_CUDA
   GET_CASTED_TENSOR;
   PADDLE_ENFORCE_GE(tensor->numel(), 0,
                     platform::errors::PreconditionNotMet(
@@ -102,67 +113,85 @@ void Tensor::copy_from_cpu(const T *data) {
                         "std::vector<int> &shape)"
                         "function before copying data from cpu."));
   size_t ele_size = tensor->numel() * sizeof(T);
-
-  if (place_ == PlaceType::kCPU) {
-    auto *t_data = tensor->mutable_data<T>(platform::CPUPlace());
-    std::memcpy(static_cast<void *>(t_data), data, ele_size);
+  Tensor target = Tensor(PlaceType::kGPU);
+  target.reshape(shape());
+  auto *p_target_data = target.template mutable_data<T>();
+  auto p_src_data = tensor->data<T>();
+
+  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
+  int device_num = platform::GetCurrentDeviceId();
+  platform::CUDAPlace gpu_place(device_num);
+  auto *dev_ctx =
+      static_cast<const platform::CUDADeviceContext *>(pool.Get(gpu_place));
+  if (platform::is_cpu_place(tensor->place())) {
+    memory::Copy(gpu_place, static_cast<void *>(p_target_data),
+                 platform::CPUPlace(), p_src_data, ele_size, dev_ctx->stream());
   } else {
-#ifdef PADDLE_WITH_CUDA
-    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
-    int device_num = platform::GetCurrentDeviceId();
-    platform::CUDAPlace gpu_place(device_num);
-    auto *t_data = tensor->mutable_data<T>(gpu_place);
-    auto *dev_ctx =
-        static_cast<const platform::CUDADeviceContext *>(pool.Get(gpu_place));
-
-    memory::Copy(gpu_place, static_cast<void *>(t_data), platform::CPUPlace(),
-                 data, ele_size, dev_ctx->stream());
+    memory::Copy(gpu_place, static_cast<void *>(p_target_data), gpu_place,
+                 p_src_data, ele_size, dev_ctx->stream());
+  }
+  cudaStreamSynchronize(dev_ctx->stream());
+  return target;
 #else
-    PADDLE_THROW(platform::errors::Unavailable(
-        "Not compiled with CUDA, should not reach here."));
+  PADDLE_THROW(
+      platform::errors::Unavailable("PaddlePaddle is not compiled with CUDA"));
 #endif
-  }
+  return Tensor(PlaceType::kGPU);
 }
 
 template <typename T>
-void Tensor::copy_to_cpu(T *data) {
+Tensor Tensor::copy_to_cpu() {
   GET_CASTED_TENSOR;
   auto ele_num = tensor->numel();
   auto *t_data = tensor->data<T>();
   auto t_place = tensor->place();
-
+  Tensor target = Tensor(PlaceType::kCPU);
+  target.reshape(shape());
+  auto *p_target_data = target.template mutable_data<T>();
   if (platform::is_cpu_place(t_place)) {
-    std::memcpy(static_cast<void *>(data), t_data, ele_num * sizeof(T));
+    std::memcpy(static_cast<void *>(p_target_data), t_data,
+                ele_num * sizeof(T));
   } else {
 #ifdef PADDLE_WITH_CUDA
     platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
     auto gpu_place = BOOST_GET_CONST(platform::CUDAPlace, t_place);
     auto *dev_ctx =
         static_cast<const platform::CUDADeviceContext *>(pool.Get(gpu_place));
-    memory::Copy(platform::CPUPlace(), static_cast<void *>(data), gpu_place,
-                 t_data, ele_num * sizeof(T), dev_ctx->stream());
+    memory::Copy(platform::CPUPlace(), static_cast<void *>(p_target_data),
+                 gpu_place, t_data, ele_num * sizeof(T), dev_ctx->stream());
     cudaStreamSynchronize(dev_ctx->stream());
 #else
     PADDLE_THROW(platform::errors::Unavailable(
-        "Not compile with CUDA, should not reach here."));
+        "PaddlePaddle is not compiled with CUDA."));
 #endif
   }
+  return target;
 }
 
-template void Tensor::copy_from_cpu<float>(const float *data);
-template void Tensor::copy_from_cpu<double>(const double *data);
-template void Tensor::copy_from_cpu<int64_t>(const int64_t *data);
-template void Tensor::copy_from_cpu<int32_t>(const int32_t *data);
-template void Tensor::copy_from_cpu<uint8_t>(const uint8_t *data);
-template void Tensor::copy_from_cpu<int8_t>(const int8_t *data);
-
-template void Tensor::copy_to_cpu<float>(float *data);
-template void Tensor::copy_to_cpu<double>(double *data);
-template void Tensor::copy_to_cpu<int64_t>(int64_t *data);
-template void Tensor::copy_to_cpu<int32_t>(int32_t *data);
-template void Tensor::copy_to_cpu<uint8_t>(uint8_t *data);
-template void Tensor::copy_to_cpu<int8_t>(int8_t *data);
+template Tensor Tensor::copy_to_gpu<float>();
+template Tensor Tensor::copy_to_gpu<double>();
+template Tensor Tensor::copy_to_gpu<int64_t>();
+template Tensor Tensor::copy_to_gpu<int32_t>();
+template Tensor Tensor::copy_to_gpu<uint8_t>();
+template Tensor Tensor::copy_to_gpu<int8_t>();
+template Tensor Tensor::copy_to_gpu<paddle::platform::float16>();
+template Tensor Tensor::copy_to_gpu<paddle::platform::bfloat16>();
+template Tensor Tensor::copy_to_gpu<paddle::platform::complex128>();
+template Tensor Tensor::copy_to_gpu<paddle::platform::complex64>();
+template Tensor Tensor::copy_to_gpu<int16_t>();
+
+template Tensor Tensor::copy_to_cpu<float>();
+template Tensor Tensor::copy_to_cpu<double>();
+template Tensor Tensor::copy_to_cpu<int64_t>();
+template Tensor Tensor::copy_to_cpu<int32_t>();
+template Tensor Tensor::copy_to_cpu<uint8_t>();
+template Tensor Tensor::copy_to_cpu<int8_t>();
+template Tensor Tensor::copy_to_cpu<paddle::platform::float16>();
+template Tensor Tensor::copy_to_cpu<paddle::platform::bfloat16>();
+template Tensor Tensor::copy_to_cpu<paddle::platform::complex128>();
+template Tensor Tensor::copy_to_cpu<paddle::platform::complex64>();
+template Tensor Tensor::copy_to_cpu<int16_t>();
 
 template float *Tensor::data<float>() const;
 template double *Tensor::data<double>() const;
@@ -170,6 +199,15 @@ template int64_t *Tensor::data<int64_t>() const;
 template int32_t *Tensor::data<int32_t>() const;
 template uint8_t *Tensor::data<uint8_t>() const;
 template int8_t *Tensor::data<int8_t>() const;
+template paddle::platform::float16 *Tensor::data<paddle::platform::float16>()
+    const;
+template paddle::platform::bfloat16 *Tensor::data<paddle::platform::bfloat16>()
+    const;
+template paddle::platform::complex128 *
+Tensor::data<paddle::platform::complex128>() const;
+template paddle::platform::complex64 *
+Tensor::data<paddle::platform::complex64>() const;
+template int16_t *Tensor::data<int16_t>() const;
 
 template float *Tensor::mutable_data<float>();
 template double *Tensor::mutable_data<double>();
@@ -177,6 +215,15 @@ template int64_t *Tensor::mutable_data<int64_t>();
 template int32_t *Tensor::mutable_data<int32_t>();
 template uint8_t *Tensor::mutable_data<uint8_t>();
 template int8_t *Tensor::mutable_data<int8_t>();
+template paddle::platform::float16 *
+Tensor::mutable_data<paddle::platform::float16>();
+template paddle::platform::bfloat16 *
+Tensor::mutable_data<paddle::platform::bfloat16>();
+template paddle::platform::complex128 *
+Tensor::mutable_data<paddle::platform::complex128>();
+template paddle::platform::complex64 *
+Tensor::mutable_data<paddle::platform::complex64>();
+template int16_t *Tensor::mutable_data<int16_t>();
 
 template float *Tensor::mutable_data<float>(const PlaceType &place);
 template double *Tensor::mutable_data<double>(const PlaceType &place);
@@ -184,30 +231,21 @@ template int64_t *Tensor::mutable_data<int64_t>(const PlaceType &place);
 template int32_t *Tensor::mutable_data<int32_t>(const PlaceType &place);
 template uint8_t *Tensor::mutable_data<uint8_t>(const PlaceType &place);
 template int8_t *Tensor::mutable_data<int8_t>(const PlaceType &place);
+template paddle::platform::float16 *
+Tensor::mutable_data<paddle::platform::float16>(const PlaceType &place);
+template paddle::platform::bfloat16 *
+Tensor::mutable_data<paddle::platform::bfloat16>(const PlaceType &place);
+template paddle::platform::complex128 *
+Tensor::mutable_data<paddle::platform::complex128>(const PlaceType &place);
+template paddle::platform::complex64 *
+Tensor::mutable_data<paddle::platform::complex64>(const PlaceType &place);
+template int16_t *Tensor::mutable_data<int16_t>(const PlaceType &place);
 
 std::vector<int> Tensor::shape() const {
   GET_CASTED_TENSOR
   return framework::vectorize<int>(tensor->dims());
 }
 
-void Tensor::SetLoD(const std::vector<std::vector<size_t>> &x) {
-  GET_CASTED_TENSOR;
-  framework::LoD lod;
-  for (auto &level : x) {
-    lod.emplace_back(level);
-  }
-  tensor->set_lod(lod);
-}
-
-std::vector<std::vector<size_t>> Tensor::lod() const {
-  GET_CASTED_TENSOR;
-  std::vector<std::vector<size_t>> res;
-  for (auto &level : tensor->lod()) {
-    res.emplace_back(level);
-  }
-  return res;
-}
-
 const PlaceType &Tensor::place() const {
   GET_CASTED_TENSOR;
   if (platform::is_cpu_place(tensor->place())) {
diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt
index 4b992a5adf6d6..c0d6f5f389a59 100644
--- a/paddle/fluid/framework/CMakeLists.txt
+++ b/paddle/fluid/framework/CMakeLists.txt
@@ -323,6 +323,7 @@ configure_file(commit.h.in commit.h)
 cc_library(custom_tensor SRCS ../extension/src/tensor.cc DEPS lod_tensor)
 cc_library(op_meta_info SRCS ../extension/src/op_meta_info.cc DEPS custom_tensor)
 cc_library(custom_operator SRCS custom_operator.cc DEPS operator op_registry device_context dynamic_loader custom_tensor op_meta_info)
+cc_test(custom_tensor_test SRCS custom_tensor_test.cc DEPS custom_tensor)
 
 set(FLUID_FRAMEWORK_MODULES proto_desc memory lod_tensor executor data_feed_proto layer dynamic_loader)
 
diff --git a/paddle/fluid/framework/custom_operator.cc b/paddle/fluid/framework/custom_operator.cc
index 7eb8b2e646583..2621f7ab4e269 100644
--- a/paddle/fluid/framework/custom_operator.cc
+++ b/paddle/fluid/framework/custom_operator.cc
@@ -25,8 +25,7 @@ limitations under the License. */
 #include
 #include
-#include "paddle/fluid/extension/include/all.h"
-
+#include "paddle/fluid/extension/include/tensor.h"
 #include "paddle/fluid/framework/attribute.h"
 #include "paddle/fluid/framework/c/c_api.h"
 #include "paddle/fluid/framework/custom_tensor_utils.h"
@@ -104,25 +103,59 @@ PlaceType ConvertInnerPlaceToEnumPlace(const platform::Place& pc) {
   return PlaceType::kUNK;
 }
 
-proto::VarType::Type ConvertEnumDTypeToInnerDType(const DataType& dtype) {
+proto::VarType::Type ConvertEnumDTypeToInnerDType(
+    const paddle::DataType& dtype) {
   switch (dtype) {
-    case DataType::FLOAT32:
-      return proto::VarType::FP32;
-    case DataType::FLOAT64:
+    case paddle::DataType::COMPLEX128:
+      return proto::VarType::COMPLEX128;
+    case paddle::DataType::COMPLEX64:
+      return proto::VarType::COMPLEX64;
+    case paddle::DataType::FLOAT64:
       return proto::VarType::FP64;
-    // TODO(chenweihang):
+    case paddle::DataType::FLOAT32:
+      return proto::VarType::FP32;
+    case paddle::DataType::FLOAT16:
+      return proto::VarType::FP16;
+    case paddle::DataType::BFLOAT16:
+      return proto::VarType::BF16;
+    case paddle::DataType::UINT8:
+      return proto::VarType::UINT8;
+    case paddle::DataType::INT8:
+      return proto::VarType::INT8;
+    case paddle::DataType::INT32:
+      return proto::VarType::INT32;
+    case paddle::DataType::INT64:
+      return proto::VarType::INT64;
     default:
      PADDLE_THROW(platform::errors::Unimplemented("Unsupported data type."));
   }
 }
 
-DataType ConvertInnerDTypeToEnumDType(const proto::VarType::Type& dtype) {
+paddle::DataType ConvertInnerDTypeToEnumDType(
+    const proto::VarType::Type& dtype) {
   switch (dtype) {
-    case proto::VarType::FP32:
-      return DataType::FLOAT32;
+    case proto::VarType::COMPLEX128:
+      return paddle::DataType::COMPLEX128;
+    case proto::VarType::COMPLEX64:
+      return paddle::DataType::COMPLEX64;
     case proto::VarType::FP64:
-      return DataType::FLOAT64;
-    // TODO(chenweihang):
+      return paddle::DataType::FLOAT64;
+    case proto::VarType::FP32:
+      return paddle::DataType::FLOAT32;
+    case proto::VarType::FP16:
+      return paddle::DataType::FLOAT16;
+    case proto::VarType::BF16:
+      return paddle::DataType::BFLOAT16;
+    case proto::VarType::INT64:
+      return paddle::DataType::INT64;
+    case proto::VarType::INT32:
+      return paddle::DataType::INT32;
+    case proto::VarType::INT8:
+      return paddle::DataType::INT8;
+    case proto::VarType::UINT8:
+      return paddle::DataType::UINT8;
+    case proto::VarType::INT16:
+      return paddle::DataType::INT16;
     default:
       PADDLE_THROW(platform::errors::Unimplemented("Unsupported data type."));
   }
 }
diff --git a/paddle/fluid/framework/custom_tensor_test.cc b/paddle/fluid/framework/custom_tensor_test.cc
new file mode 100644
index 0000000000000..6688bedee2685
--- /dev/null
+++ b/paddle/fluid/framework/custom_tensor_test.cc
@@ -0,0 +1,128 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "gtest/gtest.h" +#include "paddle/extension.h" +#include "paddle/fluid/framework/lod_tensor.h" + +template +paddle::Tensor InitCPUTensorForTest() { + std::vector tensor_shape = {5, 5}; + auto t1 = paddle::Tensor(paddle::PlaceType::kCPU); + t1.reshape(tensor_shape); + auto* p_data_ptr = t1.mutable_data(paddle::PlaceType::kCPU); + for (int64_t i = 0; i < t1.size(); i++) { + p_data_ptr[i] = 5; + } + return t1; +} +template +void TestCopyTensor() { + auto t1 = InitCPUTensorForTest(); + auto t1_cpu_cp = t1.template copy_to_cpu(); + CHECK((paddle::PlaceType::kCPU == t1_cpu_cp.place())); + for (int64_t i = 0; i < t1.size(); i++) { + CHECK_EQ(t1_cpu_cp.template data()[i], 5); + } + auto t1_gpu_cp = t1_cpu_cp.template copy_to_gpu(); + CHECK((paddle::PlaceType::kGPU == t1_gpu_cp.place())); + auto t1_gpu_cp_cp = t1_gpu_cp.template copy_to_gpu(); + CHECK((paddle::PlaceType::kGPU == t1_gpu_cp_cp.place())); + auto t1_gpu_cp_cp_cpu = t1_gpu_cp.template copy_to_cpu(); + CHECK((paddle::PlaceType::kCPU == t1_gpu_cp_cp_cpu.place())); + for (int64_t i = 0; i < t1.size(); i++) { + CHECK_EQ(t1_gpu_cp_cp_cpu.template data()[i], 5); + } +} + +void TestAPIPlace() { + std::vector tensor_shape = {5, 5}; + auto t1 = paddle::Tensor(paddle::PlaceType::kGPU); + t1.reshape(tensor_shape); + t1.mutable_data(); + auto t2 = paddle::Tensor(paddle::PlaceType::kCPU); + t2.reshape(tensor_shape); + t2.mutable_data(); + CHECK((paddle::PlaceType::kGPU == t1.place())); + CHECK((paddle::PlaceType::kCPU == t2.place())); +} + +void TestAPISizeAndShape() { + std::vector tensor_shape = {5, 5}; + auto t1 = paddle::Tensor(paddle::PlaceType::kCPU); + t1.reshape(tensor_shape); + CHECK_EQ(t1.size(), 25); + CHECK(t1.shape() == tensor_shape); +} + +template +paddle::DataType TestDtype() { + std::vector tensor_shape = {5, 5}; + auto t1 = paddle::Tensor(paddle::PlaceType::kCPU); + t1.reshape(tensor_shape); + t1.template mutable_data(); + return t1.type(); +} + +void GroupTestCopy() { + VLOG(0) << "Float cpu-cpu-gpu-gpu-cpu"; + TestCopyTensor(); + VLOG(0) << "Double cpu-cpu-gpu-gpu-cpu"; + TestCopyTensor(); + // TODO(JiabinYang): Support these test later + // VLOG(0) << "Fp16 cpu-cpu-gpu-gpu-cpu"; + // TestCopyTensor(); + // VLOG(0) << "BF16 cpu-cpu-gpu-gpu-cpu"; + // TestCopyTensor(); + // VLOG(0) << "complex128 cpu-cpu-gpu-gpu-cpu"; + // TestCopyTensor(); + // VLOG(0) << "complex64 cpu-cpu-gpu-gpu-cpu"; + // TestCopyTensor(); + // VLOG(0) << "int cpu-cpu-gpu-gpu-cpu"; + TestCopyTensor(); + VLOG(0) << "int64 cpu-cpu-gpu-gpu-cpu"; + TestCopyTensor(); + VLOG(0) << "int16 cpu-cpu-gpu-gpu-cpu"; + TestCopyTensor(); + VLOG(0) << "int8 cpu-cpu-gpu-gpu-cpu"; + TestCopyTensor(); + VLOG(0) << "uint8 cpu-cpu-gpu-gpu-cpu"; + TestCopyTensor(); +} +void GroupTestDtype() { + CHECK(TestDtype() == paddle::DataType::FLOAT32); + CHECK(TestDtype() == paddle::DataType::FLOAT64); + CHECK(TestDtype() == paddle::DataType::FLOAT16); + CHECK(TestDtype() == paddle::DataType::BFLOAT16); + CHECK(TestDtype() == + paddle::DataType::COMPLEX128); + CHECK(TestDtype() == + paddle::DataType::COMPLEX64); + CHECK(TestDtype() == paddle::DataType::INT32); + CHECK(TestDtype() == paddle::DataType::INT64); + CHECK(TestDtype() == paddle::DataType::INT16); + CHECK(TestDtype() == paddle::DataType::INT8); + CHECK(TestDtype() == paddle::DataType::UINT8); +} + +TEST(CustomTensor, copyTest) { + VLOG(0) << "TestCopy"; + GroupTestCopy(); + VLOG(0) << "TestDtype"; + GroupTestDtype(); + VLOG(0) << "TestShape"; + TestAPISizeAndShape(); + VLOG(0) << "TestPlace"; + TestAPIPlace(); +} diff 
diff --git a/python/paddle/fluid/tests/custom_op/relu_op_simple.cc b/python/paddle/fluid/tests/custom_op/relu_op_simple.cc
index a5c4b271124c4..684466a734147 100644
--- a/python/paddle/fluid/tests/custom_op/relu_op_simple.cc
+++ b/python/paddle/fluid/tests/custom_op/relu_op_simple.cc
@@ -39,7 +39,7 @@ void relu_cpu_backward_kernel(const data_t* grad_out_data,
 
 std::vector<paddle::Tensor> relu_cpu_forward(const paddle::Tensor& x) {
   auto out = paddle::Tensor(paddle::PlaceType::kCPU);
-  out.Reshape(x.shape());
+  out.reshape(x.shape());
 
   PD_DISPATCH_FLOATING_TYPES(
       x.type(), "relu_cpu_forward", ([&] {
@@ -54,7 +54,7 @@ std::vector<paddle::Tensor> relu_cpu_backward(const paddle::Tensor& x,
                                               const paddle::Tensor& out,
                                               const paddle::Tensor& grad_out) {
   auto grad_x = paddle::Tensor(paddle::PlaceType::kCPU);
-  grad_x.Reshape(x.shape());
+  grad_x.reshape(x.shape());
 
   PD_DISPATCH_FLOATING_TYPES(out.type(), "relu_cpu_backward", ([&] {
                                relu_cpu_backward_kernel<data_t>(
diff --git a/python/paddle/fluid/tests/custom_op/relu_op_simple.cu b/python/paddle/fluid/tests/custom_op/relu_op_simple.cu
index b3d15824424bb..a9ce517607093 100644
--- a/python/paddle/fluid/tests/custom_op/relu_op_simple.cu
+++ b/python/paddle/fluid/tests/custom_op/relu_op_simple.cu
@@ -37,7 +37,7 @@ __global__ void relu_cuda_backward_kernel(const data_t* dy,
 
 std::vector<paddle::Tensor> relu_cuda_forward(const paddle::Tensor& x) {
   auto out = paddle::Tensor(paddle::PlaceType::kGPU);
-  out.Reshape(x.shape());
+  out.reshape(x.shape());
 
   int numel = x.size();
   int block = 512;
@@ -55,7 +55,7 @@ std::vector<paddle::Tensor> relu_cuda_backward(const paddle::Tensor& x,
                                                const paddle::Tensor& out,
                                                const paddle::Tensor& grad_out) {
   auto grad_x = paddle::Tensor(paddle::PlaceType::kGPU);
-  grad_x.Reshape(x.shape());
+  grad_x.reshape(x.shape());
 
   int numel = out.size();
   int block = 512;
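For completeness, this is roughly what a user-defined kernel looks like against the renamed API, mirroring relu_op_simple.cc above. The operator (`scale_cpu_forward`) is hypothetical and the registration boilerplate is omitted; only `reshape()`, `mutable_data()`, `data()` and `PD_DISPATCH_FLOATING_TYPES` from the API surface touched by this patch are used.

```cpp
#include <vector>
#include "paddle/extension.h"

std::vector<paddle::Tensor> scale_cpu_forward(const paddle::Tensor& x) {
  auto out = paddle::Tensor(paddle::PlaceType::kCPU);
  out.reshape(x.shape());  // lower-case reshape() after this change

  // data_t is provided by the dispatch macro, as in relu_op_simple.cc.
  PD_DISPATCH_FLOATING_TYPES(x.type(), "scale_cpu_forward", ([&] {
                               const auto* x_data = x.data<data_t>();
                               auto* out_data =
                                   out.mutable_data<data_t>(x.place());
                               for (int64_t i = 0; i < x.size(); ++i) {
                                 out_data[i] =
                                     x_data[i] * static_cast<data_t>(2);
                               }
                             }));

  return {out};
}
```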