Skip to content

Commit

Permalink
[Phi] Replace Backend by Place in C++ API (#40732)
Browse files Browse the repository at this point in the history
* replace Backend by Place in C++ API

* fix left code

* fix test_to_api bug
  • Loading branch information
zyfncg committed Mar 22, 2022
1 parent 67b46e4 commit 5b7fade
Show file tree
Hide file tree
Showing 22 changed files with 73 additions and 89 deletions.
2 changes: 1 addition & 1 deletion paddle/fluid/distributed/collective/ProcessGroupNCCL.cc
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Barrier(
platform::CUDADeviceGuard gpuGuard;
for (auto& place : places) {
gpuGuard.SetDeviceIndex(place.GetDeviceId());
auto dt = full({1}, 0, phi::DataType::FLOAT32, phi::Backend::GPU);
auto dt = full({1}, 0, phi::DataType::FLOAT32, phi::GPUPlace());
barrierTensors.push_back(dt);
}
auto task = ProcessGroupNCCL::AllReduce(barrierTensors);
Expand Down
6 changes: 2 additions & 4 deletions paddle/fluid/distributed/collective/reducer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ EagerReducer::EagerReducer(
if (find_unused_vars_each_step_) {
global_used_vars_ = paddle::experimental::empty(
ScalarArray({static_cast<int32_t>(tensors_.size())}), DataType::INT32,
TransToBackend(inner_place_));
inner_place_);
}
}

Expand Down Expand Up @@ -363,10 +363,8 @@ void EagerReducer::InitializeGroups(
} else {
// process the dense gradient.
InitializeDenseGroups(tensor_indices_, &group);
// experimental::Backend backend = TransToBackend(inner_place_);
group.dense_contents_ = paddle::experimental::empty(
ScalarArray({group.all_length_}), group.dtype_,
TransToBackend(inner_place_));
ScalarArray({group.all_length_}), group.dtype_, inner_place_);
}

// map tensors to this group by VariableLocator
Expand Down
3 changes: 1 addition & 2 deletions paddle/fluid/eager/api/utils/tensor_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,7 @@ paddle::experimental::Tensor CreateTensorWithValue(
const phi::DataType& dtype, const phi::DataLayout& layout, float value,
bool is_leaf) {
paddle::experimental::Tensor out = paddle::experimental::full(
phi::vectorize(ddim), paddle::experimental::Scalar(value), dtype,
phi::TransToPhiBackend(place));
phi::vectorize(ddim), paddle::experimental::Scalar(value), dtype, place);

auto meta = EagerUtils::autograd_meta(&out);
if (is_leaf) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
'int' : 'int', 'int32' : 'int32_t', 'int64' : 'int64_t', 'size_t' : 'size_t', \
'float' : 'float', 'double' : 'double', 'bool' : 'bool', \
'str' : 'std::string', \
'Backend' : 'paddle::experimental::Backend', 'DataLayout' : 'paddle::experimental::DataLayout', 'DataType' : 'paddle::experimental::DataType', \
'Place' : 'paddle::experimental::Place', 'DataLayout' : 'paddle::experimental::DataLayout', 'DataType' : 'paddle::experimental::DataType', \
'int64[]' : 'std::vector<int64_t>', 'int[]' : 'std::vector<int>',
'Tensor' : 'Tensor',
'Tensor[]' : 'std::vector<Tensor>',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def SkipAPIGeneration(forward_api_name):
"std::vector<std::string>": "CastPyArg2Strings",
"paddle::experimental::Scalar": "CastPyArg2Scalar",
"paddle::experimental::ScalarArray": "CastPyArg2ScalarArray",
"paddle::experimental::Backend": "CastPyArg2Backend",
"paddle::experimental::Place": "CastPyArg2Place",
"paddle::experimental::DataType": "CastPyArg2DataType",
}

Expand Down
6 changes: 2 additions & 4 deletions paddle/fluid/pybind/eager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,7 @@ void InitTensorWithTensor(TensorObject* self,
self->tensor.set_impl(impl);
VLOG(4) << "Same place, do ShareDataWith";
} else {
self->tensor.set_impl(
src.copy_to(phi::TransToPhiBackend(place), true).impl());
self->tensor.set_impl(src.copy_to(place, true).impl());
VLOG(4) << "Different place, do TensorCopy";
}
if (src.get_autograd_meta()) {
Expand All @@ -156,8 +155,7 @@ void InitTensorWithFrameworkTensor(TensorObject* self,
} else {
auto temp =
paddle::experimental::Tensor(std::make_shared<phi::DenseTensor>(src));
self->tensor.set_impl(
temp.copy_to(phi::TransToPhiBackend(place), true).impl());
self->tensor.set_impl(temp.copy_to(place, true).impl());
VLOG(4) << "Different place, do TensorCopy";
}
egr::EagerUtils::autograd_meta(&(self->tensor))->SetPersistable(false);
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/pybind/eager_functions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ static PyObject* eager_api_tensor_copy(PyObject* self, PyObject* args,
auto place = CastPyArg2Place(PyTuple_GET_ITEM(args, 2), 2);
bool blocking = CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 3), 3);

dst = src.copy_to(phi::TransToPhiBackend(place), blocking);
dst = src.copy_to(place, blocking);
egr::EagerUtils::autograd_meta(&dst)->SetStopGradient(
egr::EagerUtils::autograd_meta(&(src))->StopGradient());
egr::EagerUtils::autograd_meta(&dst)->SetPersistable(
Expand Down
6 changes: 2 additions & 4 deletions paddle/fluid/pybind/eager_method.cc
Original file line number Diff line number Diff line change
Expand Up @@ -218,8 +218,7 @@ static PyObject* tensor_method__copy_to(TensorObject* self, PyObject* args,
EAGER_TRY
auto place = CastPyArg2Place(PyTuple_GET_ITEM(args, 0), 0);
bool blocking = CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 1), 1);
auto cp_tensor =
self->tensor.copy_to(phi::TransToPhiBackend(place), blocking);
auto cp_tensor = self->tensor.copy_to(place, blocking);
egr::EagerUtils::autograd_meta(&cp_tensor)->SetStopGradient(true);
egr::EagerUtils::autograd_meta(&cp_tensor)
->SetPersistable(
Expand All @@ -231,8 +230,7 @@ static PyObject* tensor_method__copy_to(TensorObject* self, PyObject* args,
static PyObject* tensor_method_cpu(TensorObject* self, PyObject* args,
PyObject* kwargs) {
EAGER_TRY
auto cp_tensor =
self->tensor.copy_to(phi::TransToPhiBackend(phi::CPUPlace()), true);
auto cp_tensor = self->tensor.copy_to(phi::CPUPlace(), true);
egr::EagerUtils::autograd_meta(&cp_tensor)->SetStopGradient(true);
egr::EagerUtils::autograd_meta(&cp_tensor)
->SetPersistable(
Expand Down
26 changes: 4 additions & 22 deletions paddle/fluid/pybind/eager_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -929,28 +929,10 @@ std::vector<paddle::framework::Scope*> GetScopePtrListFromArgs(
return result;
}

paddle::experimental::Backend CastPyArg2Backend(PyObject* obj,
const std::string& op_type,
ssize_t arg_pos) {
if (obj == Py_None) {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"int or place, but got %s",
op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}

PyTypeObject* type = obj->ob_type;
auto type_name = std::string(type->tp_name);
if (type_name == "int") {
int value = CastPyArg2Int(obj, op_type, arg_pos);
return static_cast<paddle::experimental::Backend>(value);
} else {
platform::Place place = CastPyArg2Place(obj, arg_pos);
return phi::TransToPhiBackend(place);
}

return paddle::experimental::Backend::CPU;
paddle::experimental::Place CastPyArg2Place(PyObject* obj,
const std::string& op_type,
ssize_t arg_pos) {
return CastPyArg2Place(obj, arg_pos);
}

paddle::experimental::DataType CastPyArg2DataType(PyObject* obj,
Expand Down
6 changes: 3 additions & 3 deletions paddle/fluid/pybind/eager_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,9 @@ paddle::experimental::Scalar CastPyArg2Scalar(PyObject* obj,
paddle::experimental::ScalarArray CastPyArg2ScalarArray(
PyObject* obj, const std::string& op_type, ssize_t arg_pos);

paddle::experimental::Backend CastPyArg2Backend(PyObject* obj,
const std::string& op_type,
ssize_t arg_pos);
paddle::experimental::Place CastPyArg2Place(PyObject* obj,
const std::string& op_type,
ssize_t arg_pos);

paddle::experimental::DataType CastPyArg2DataType(PyObject* obj,
const std::string& op_type,
Expand Down
5 changes: 2 additions & 3 deletions paddle/phi/api/include/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ using gpuStream_t = hipStream_t;

#include "paddle/phi/api/ext/dll_decl.h"
#include "paddle/phi/api/ext/place.h"
#include "paddle/phi/common/backend.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/common/place.h"
Expand Down Expand Up @@ -415,11 +414,11 @@ class PADDLE_API Tensor final {
/**
* @brief Transfer the current Tensor to the specified device and return.
*
* @param backend, The target backend of which the tensor will copy to.
* @param place, The target place of which the tensor will copy to.
* @param blocking, Should we copy this in sync way.
* @return Tensor
*/
Tensor copy_to(Backend backend, bool blocking) const;
Tensor copy_to(Place place, bool blocking) const;

/**
* @brief Transfer the source Tensor to current Tensor.
Expand Down
9 changes: 5 additions & 4 deletions paddle/phi/api/lib/api_custom_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/phi/api/lib/data_transform.h"
#include "paddle/phi/api/lib/kernel_dispatch.h"
#include "paddle/phi/api/lib/utils/storage.h"
#include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/meta_tensor.h"
#include "paddle/phi/infermeta/binary.h"
Expand All @@ -31,9 +32,10 @@ limitations under the License. */
namespace paddle {
namespace experimental {

Tensor copy_to_impl(const Tensor& x, Backend backend, bool blocking) {
Tensor copy_to_impl(const Tensor& x, Place place, bool blocking) {
auto kernel_key_set = ParseKernelKeyByInputArgs(x);
kernel_key_set.backend_set = kernel_key_set.backend_set | BackendSet(backend);
kernel_key_set.backend_set =
kernel_key_set.backend_set | BackendSet(phi::TransToPhiBackend(place));
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
"copy", kernel_key);
Expand All @@ -57,8 +59,7 @@ Tensor copy_to_impl(const Tensor& x, Backend backend, bool blocking) {
phi::DenseTensor*);

auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)(
*dev_ctx, *dense_x, phi::TransToPhiPlace(backend), blocking, kernel_out);
(*kernel_fn)(*dev_ctx, *dense_x, place, blocking, kernel_out);

return out;
}
Expand Down
5 changes: 2 additions & 3 deletions paddle/phi/api/lib/api_custom_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,14 @@ limitations under the License. */
#pragma once

#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/common/backend.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/common/scalar_array.h"

namespace paddle {
namespace experimental {

// TODO(chenweihang): Replace backend by place when place is ready
Tensor copy_to_impl(const Tensor& x, Backend backend, bool blocking);
Tensor copy_to_impl(const Tensor& x, Place place, bool blocking);

std::vector<Tensor> split_impl(const Tensor& x,
const ScalarArray& num_or_sections,
Expand Down
10 changes: 7 additions & 3 deletions paddle/phi/api/lib/kernel_dispatch.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,17 @@ DataType ParseDataTypeWithInputOrder(DataType dtype, const Tensor& tensor) {
return dtype != DataType::UNDEFINED ? dtype : ParseDataType(tensor);
}

Backend ParseBackend(Backend backend) { return backend; }
Backend ParseBackend(const Place& place) {
return phi::TransToPhiBackend(place);
}
Backend ParseBackend(const Tensor& tensor) {
return phi::TransToPhiBackend(tensor.inner_place());
}

Backend ParseBackendWithInputOrder(Backend backend, const Tensor& tensor) {
return backend != Backend::UNDEFINED ? backend : ParseBackend(tensor);
Backend ParseBackendWithInputOrder(const Place& place, const Tensor& tensor) {
return place.GetType() != phi::AllocationType::UNDEFINED
? ParseBackend(place)
: ParseBackend(tensor);
}

DataLayout ParseLayout(DataLayout layout) { return layout; }
Expand Down
4 changes: 2 additions & 2 deletions paddle/phi/api/lib/kernel_dispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ DataType ParseDataType(const Tensor& tensor);
DataType ParseDataType(const std::vector<Tensor>& tensors);
DataType ParseDataTypeWithInputOrder(DataType dtype, const Tensor& tensor);

Backend ParseBackend(Backend backend);
Backend ParseBackend(const Place& place);
Backend ParseBackend(const Tensor& tensor);
template <typename T, typename... Args>
Backend ParseBackend(T t, Args... args) {
Expand All @@ -163,7 +163,7 @@ Backend ParseBackend(T t, Args... args) {
return static_cast<Backend>(64 -
detail::CountLeadingZeros(backend_set.bitset()));
}
Backend ParseBackendWithInputOrder(Backend backend, const Tensor& tensor);
Backend ParseBackendWithInputOrder(const Place& place, const Tensor& tensor);

DataLayout ParseLayout(DataLayout layout);
DataLayout ParseLayout(const Tensor& tensor);
Expand Down
8 changes: 4 additions & 4 deletions paddle/phi/api/lib/tensor_method.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,14 @@ namespace paddle {
namespace experimental {
// declare cast api
Tensor cast(const Tensor &x, DataType out_dtype);
Tensor copy_to(const Tensor &x, Backend backend, bool blocking);
Tensor copy_to(const Tensor &x, Place place, bool blocking);

Tensor Tensor::cast(DataType target_type) const {
return experimental::cast(*this, target_type);
}

Tensor Tensor::copy_to(Backend backend, bool blocking) const {
return experimental::copy_to(*this, backend, blocking);
Tensor Tensor::copy_to(Place place, bool blocking) const {
return experimental::copy_to(*this, place, blocking);
}

template <typename T>
Expand All @@ -44,7 +44,7 @@ Tensor Tensor::copy_to(const PlaceType &target_place) const {
"`copy_to` method without template argument instead. "
"reason: copying a Tensor to another device does not need "
"to specify the data type template argument.";
return copy_to(ConvertExtPlaceToBackend(target_place), /*blocking=*/false);
return copy_to(ConvertExtPlaceToInnerPlace(target_place), /*blocking=*/false);
}

template PADDLE_API Tensor
Expand Down
5 changes: 5 additions & 0 deletions paddle/phi/common/place.h
Original file line number Diff line number Diff line change
Expand Up @@ -203,5 +203,10 @@ namespace paddle {
namespace experimental {
using AllocationType = phi::AllocationType;
using Place = phi::Place;
using CPUPlace = phi::CPUPlace;
using GPUPlace = phi::GPUPlace;
using GPUPinnedPlace = phi::GPUPinnedPlace;
using XPUPlace = phi::XPUPlace;
using NPUPlace = phi::NPUPlace;
} // namespace experimental
} // namespace paddle
11 changes: 6 additions & 5 deletions paddle/phi/tests/api/test_data_transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License. */

#include "paddle/phi/api/include/api.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
Expand All @@ -39,10 +40,10 @@ TEST(API, data_transform_same_place) {
auto x = paddle::experimental::full({3, 3},
1.0,
experimental::DataType::COMPLEX128,
experimental::Backend::CPU);
experimental::CPUPlace());

auto y = paddle::experimental::full(
{3, 3}, 2.0, experimental::DataType::FLOAT32, experimental::Backend::CPU);
{3, 3}, 2.0, experimental::DataType::FLOAT32, experimental::CPUPlace());

std::vector<phi::dtype::complex<double>> sum(9, 6.0);

Expand Down Expand Up @@ -74,10 +75,10 @@ TEST(API, data_transform_same_place) {
TEST(Tensor, data_transform_diff_place) {
// 1. create tensor
auto x = paddle::experimental::full(
{3, 3}, 1.0, experimental::DataType::FLOAT64, experimental::Backend::CPU);
{3, 3}, 1.0, experimental::DataType::FLOAT64, experimental::CPUPlace());

auto y = paddle::experimental::full(
{3, 3}, 2.0, experimental::DataType::FLOAT64, experimental::Backend::GPU);
{3, 3}, 2.0, experimental::DataType::FLOAT64, experimental::GPUPlace());

std::vector<float> sum(9, 6.0);

Expand All @@ -95,7 +96,7 @@ TEST(Tensor, data_transform_diff_place) {
ASSERT_EQ(out.impl()->place(),
phi::TransToPhiPlace(experimental::Backend::GPU));

auto ref_out = experimental::copy_to(out, experimental::Backend::CPU, true);
auto ref_out = experimental::copy_to(out, experimental::CPUPlace(), true);

auto dense_out = std::dynamic_pointer_cast<phi::DenseTensor>(ref_out.impl());
for (size_t i = 0; i < 9; i++) {
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/tests/api/test_scale_benchmark.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ namespace tests {

TEST(API, scale) {
auto x = experimental::full(
{3, 4}, 1.0, experimental::DataType::FLOAT32, experimental::Backend::CPU);
{3, 4}, 1.0, experimental::DataType::FLOAT32, experimental::CPUPlace());

const size_t cycles = 300;
phi::tests::Timer timer;
Expand Down
12 changes: 6 additions & 6 deletions paddle/phi/tests/api/test_to_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,10 @@ TEST(API, copy_to) {

// 2. test API
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
auto tmp = paddle::experimental::copy_to(x, phi::Backend::GPU, false);
auto out = paddle::experimental::copy_to(tmp, phi::Backend::CPU, true);
auto tmp = paddle::experimental::copy_to(x, phi::GPUPlace(), false);
auto out = paddle::experimental::copy_to(tmp, phi::CPUPlace(), true);
#else
auto out = paddle::experimental::copy_to(x, phi::Backend::CPU, false);
auto out = paddle::experimental::copy_to(x, phi::CPUPlace(), false);
#endif

// 3. check result
Expand All @@ -85,10 +85,10 @@ TEST(Tensor, copy_to) {

// 2. test API
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
auto tmp = x.copy_to(phi::Backend::GPU, false);
auto out = tmp.copy_to(phi::Backend::CPU, true);
auto tmp = x.copy_to(phi::GPUPlace(), false);
auto out = tmp.copy_to(phi::CPUPlace(), true);
#else
auto out = x.copy_to(phi::Backend::CPU, false);
auto out = x.copy_to(phi::CPUPlace(), false);
#endif

// 3. check result
Expand Down
Loading

0 comments on commit 5b7fade

Please sign in to comment.