From 587e1c24cfd0ec4d288ddb580f82dede69a17ea4 Mon Sep 17 00:00:00 2001 From: Wen Sun <35923278+HermitSun@users.noreply.github.com> Date: Fri, 16 Sep 2022 07:59:12 +0800 Subject: [PATCH] Support both use_calc_stream and sync_op in send recv APIs (#46023) --- .../distributed/collective/ProcessGroup.h | 50 +++- .../collective/ProcessGroupNCCL.cc | 226 +++++++++++++++++- .../distributed/collective/ProcessGroupNCCL.h | 49 +++- .../collective/ProcessGroupStream.cc | 84 +++++++ .../collective/ProcessGroupStream.h | 52 ++++ paddle/fluid/pybind/distributed_py.cc | 170 +++++++++++++ .../communication/stream/__init__.py | 4 +- .../communication/stream/all_reduce.py | 8 +- .../distributed/communication/stream/recv.py | 82 +++++++ .../distributed/communication/stream/send.py | 82 +++++++ .../tests/unittests/collective/CMakeLists.txt | 32 ++- ...mmunication_stream_sendrecv_api_dygraph.py | 68 ++++++ .../test_communication_stream_sendrecv_api.py | 50 ++++ .../tests/unittests/collective/testslist.csv | 5 +- 14 files changed, 922 insertions(+), 40 deletions(-) create mode 100644 python/paddle/distributed/communication/stream/recv.py create mode 100644 python/paddle/distributed/communication/stream/send.py create mode 100644 python/paddle/fluid/tests/unittests/collective/communication_stream_sendrecv_api_dygraph.py create mode 100644 python/paddle/fluid/tests/unittests/collective/test_communication_stream_sendrecv_api.py diff --git a/paddle/fluid/distributed/collective/ProcessGroup.h b/paddle/fluid/distributed/collective/ProcessGroup.h index 10b1686ddb85f..3db2464e59afd 100644 --- a/paddle/fluid/distributed/collective/ProcessGroup.h +++ b/paddle/fluid/distributed/collective/ProcessGroup.h @@ -134,24 +134,56 @@ class ProcessGroup { "ProcessGroup%s does not support send", GetBackendName())); } + virtual std::shared_ptr Send( + std::vector&, int, bool) { // NOLINT + PADDLE_THROW(platform::errors::InvalidArgument( + "ProcessGroup%s does not support send with sync_op flag", + GetBackendName())); + } + virtual std::shared_ptr Recv( - std::vector& tensors, int) { // NOLINT + std::vector&, int) { // NOLINT PADDLE_THROW(platform::errors::InvalidArgument( - "ProcessGroup%s does not support receive", GetBackendName())); + "ProcessGroup%s does not support recv", GetBackendName())); } - virtual std::shared_ptr Send_Partial(phi::DenseTensor&, - int, - int, - int) { // NOLINT + virtual std::shared_ptr Recv( + std::vector&, int, bool) { // NOLINT PADDLE_THROW(platform::errors::InvalidArgument( - "ProcessGroup%s does not support send", GetBackendName())); + "ProcessGroup%s does not support recv with sync_op flag", + GetBackendName())); + } + + virtual std::shared_ptr Send_Partial( + phi::DenseTensor&, // NOLINT + int, + int, + int) { + PADDLE_THROW(platform::errors::InvalidArgument( + "ProcessGroup%s does not support send_partial", GetBackendName())); + } + + virtual std::shared_ptr Send_Partial( + phi::DenseTensor&, int, int, int, bool) { // NOLINT + PADDLE_THROW(platform::errors::InvalidArgument( + "ProcessGroup%s does not support send_partial with sync_op flag", + GetBackendName())); } virtual std::shared_ptr Recv_Partial( - phi::DenseTensor& tensors, int, int, int) { // NOLINT + phi::DenseTensor&, // NOLINT + int, + int, + int) { PADDLE_THROW(platform::errors::InvalidArgument( - "ProcessGroup%s does not support receive", GetBackendName())); + "ProcessGroup%s does not support recv_partial", GetBackendName())); + } + + virtual std::shared_ptr Recv_Partial( + phi::DenseTensor&, int, int, int, bool) { // NOLINT + 
PADDLE_THROW(platform::errors::InvalidArgument( + "ProcessGroup%s does not support recv_partial with sync_op flag", + GetBackendName())); } virtual std::shared_ptr AllGather( diff --git a/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc b/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc index 239114ae6188c..368008d9cc0ce 100644 --- a/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc +++ b/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc @@ -51,6 +51,17 @@ std::shared_ptr ProcessGroupNCCL::CreateTask( places, rank, comm_type, inputs); } +std::shared_ptr ProcessGroupNCCL::CreateTask( + const std::vector& places, + int rank, + CommType comm_type, + const std::vector& inputs, + bool is_sync, + bool use_calc_stream) { + return std::make_shared( + places, rank, comm_type, inputs, is_sync, use_calc_stream); +} + ProcessGroupNCCL::NCCLTask::NCCLTask( const std::vector& places, int rank, @@ -264,10 +275,12 @@ std::shared_ptr ProcessGroupNCCL::Collective( auto& nccl_comms = places_to_ncclcomm_[key]; - SyncDefaultStream(places, places_to_events_[key], places_to_ctx_[key]); + if (!use_calc_stream) { + SyncDefaultStream(places, places_to_events_[key], places_to_ctx_[key]); + } - auto task = std::make_shared( - places, rank_, comm_type, inputs, sync_op, use_calc_stream); + auto task = + CreateTask(places, rank_, comm_type, inputs, sync_op, use_calc_stream); platform::CUDADeviceGuard cuda_guard; @@ -406,6 +419,78 @@ void ProcessGroupNCCL::Collective(const phi::DenseTensor* in, cuda_guard.SetDevice(places[0]); } +template +std::shared_ptr ProcessGroupNCCL::PointToPoint( + std::vector& tensors, + Fn fn, + int dst_rank, + CommType op_type, + bool sync_op, + bool use_calc_stream) { + const auto& places = GetPlaceList(tensors); + const auto& key = GetKeyFromPlaces(places); + + { + std::lock_guard lock(mutex_); + if (places_to_ncclcomm_.find(key) == places_to_ncclcomm_.end()) { + CreateNCCLManagerCache(key, places); + } + } + + auto& nccl_comms = places_to_ncclcomm_[key]; + + if (!use_calc_stream) { + SyncDefaultStream(places, places_to_events_[key], places_to_ctx_[key]); + } + + auto task = + CreateTask(places, rank_, op_type, tensors, sync_op, use_calc_stream); + + platform::CUDADeviceGuard cuda_guard; + + if (FLAGS_use_stream_safe_cuda_allocator) { + for (size_t i = 0; i < tensors.size(); ++i) { + cuda_guard.SetDevice(places[i]); + gpuStream_t nccl_stream; + if (use_calc_stream) { + nccl_stream = + static_cast( + platform::DeviceContextPool::Instance().Get(places[i])) + ->stream(); + } else { + nccl_stream = places_to_ctx_[key][i]->stream(); + } + memory::RecordStream(tensors[i].Holder(), nccl_stream); + } + } + + { + platform::NCCLGroupGuard nccl_guard; + for (size_t i = 0; i < tensors.size(); ++i) { + cuda_guard.SetDevice(places[i]); + gpuStream_t nccl_stream; + if (use_calc_stream) { + nccl_stream = + static_cast( + platform::DeviceContextPool::Instance().Get(places[i])) + ->stream(); + } else { + nccl_stream = places_to_ctx_[key][i]->stream(); + } + fn(tensors[i], nccl_comms[i]->GetNcclComm(), nccl_stream, dst_rank); + } + } + + if (!use_calc_stream) { + for (size_t i = 0; i < tensors.size(); ++i) { + cuda_guard.SetDevice(places[i]); + task->control_events_[i].Record(*places_to_ctx_[key][i]); + } + } + + return task; +} + template std::shared_ptr ProcessGroupNCCL::PointToPoint( std::vector& tensors, @@ -617,6 +702,34 @@ std::shared_ptr ProcessGroupNCCL::Send( return task; } +std::shared_ptr ProcessGroupNCCL::Send( + std::vector& tensors, + int dst_rank, + bool sync_op, + bool 
use_calc_stream) { + CheckTensorsInDifferentDevices(tensors, static_cast(GetSize())); + + auto task = PointToPoint( + tensors, + [&](phi::DenseTensor& input, + ncclComm_t comm, + const gpuStream_t& stream, + int dst_rank) { + return platform::dynload::ncclSend( + input.data(), + input.numel(), + platform::ToNCCLDataType(input.dtype()), + dst_rank, + comm, + stream); + }, + dst_rank, + CommType::SEND, + sync_op, + use_calc_stream); + return task; +} + std::shared_ptr ProcessGroupNCCL::Recv( std::vector& tensors, int src_rank) { CheckTensorsInDifferentDevices(tensors, static_cast(GetSize())); @@ -640,6 +753,34 @@ std::shared_ptr ProcessGroupNCCL::Recv( return task; } +std::shared_ptr ProcessGroupNCCL::Recv( + std::vector& tensors, + int src_rank, + bool sync_op, + bool use_calc_stream) { + CheckTensorsInDifferentDevices(tensors, static_cast(GetSize())); + + auto task = PointToPoint( + tensors, + [&](phi::DenseTensor& output, + ncclComm_t comm, + const gpuStream_t& stream, + int src_rank) { + return platform::dynload::ncclRecv( + output.data(), + output.numel(), + platform::ToNCCLDataType(output.dtype()), + src_rank, + comm, + stream); + }, + src_rank, + CommType::RECV, + sync_op, + use_calc_stream); + return task; +} + std::shared_ptr ProcessGroupNCCL::Send_Partial( phi::DenseTensor& tensors, int dst_rank, int offset, int length) { // CheckTensorsInDifferentDevices(tensors, static_cast(GetSize())); @@ -647,10 +788,8 @@ std::shared_ptr ProcessGroupNCCL::Send_Partial( phi::DenseTensor flatten_tensor; flatten_tensor.ShareDataWith(tensors).Resize({tensors.numel()}); - phi::DenseTensor shared_input = flatten_tensor.Slice(offset, offset + length); - - std::vector shared_tensors; - shared_tensors.push_back(shared_input); + std::vector shared_tensors{ + flatten_tensor.Slice(offset, offset + length)}; auto task = PointToPoint( shared_tensors, @@ -671,16 +810,49 @@ std::shared_ptr ProcessGroupNCCL::Send_Partial( return task; } +std::shared_ptr ProcessGroupNCCL::Send_Partial( + phi::DenseTensor& tensors, + int dst_rank, + int offset, + int length, + bool sync_op, + bool use_calc_stream) { + phi::DenseTensor flatten_tensor; + flatten_tensor.ShareDataWith(tensors).Resize({tensors.numel()}); + + std::vector shared_tensors{ + flatten_tensor.Slice(offset, offset + length)}; + + auto task = PointToPoint( + shared_tensors, + [&](phi::DenseTensor& input, + ncclComm_t comm, + const gpuStream_t& stream, + int dst_rank) { + return platform::dynload::ncclSend( + input.data(), + input.numel(), + platform::ToNCCLDataType(input.dtype()), + dst_rank, + comm, + stream); + }, + dst_rank, + CommType::SEND, + sync_op, + use_calc_stream); + return task; +} + std::shared_ptr ProcessGroupNCCL::Recv_Partial( phi::DenseTensor& tensors, int src_rank, int offset, int length) { // phi::DenseTensor shared_input = tensors.Slice(offset, offset+length); phi::DenseTensor flatten_tensor; flatten_tensor.ShareDataWith(tensors).Resize({tensors.numel()}); - phi::DenseTensor shared_input = flatten_tensor.Slice(offset, offset + length); - std::vector shared_tensors; - shared_tensors.push_back(shared_input); + std::vector shared_tensors{ + flatten_tensor.Slice(offset, offset + length)}; auto task = PointToPoint( shared_tensors, @@ -701,6 +873,40 @@ std::shared_ptr ProcessGroupNCCL::Recv_Partial( return task; } +std::shared_ptr ProcessGroupNCCL::Recv_Partial( + phi::DenseTensor& tensors, + int src_rank, + int offset, + int length, + bool sync_op, + bool use_calc_stream) { + phi::DenseTensor flatten_tensor; + 
flatten_tensor.ShareDataWith(tensors).Resize({tensors.numel()}); + + std::vector shared_tensors{ + flatten_tensor.Slice(offset, offset + length)}; + + auto task = PointToPoint( + shared_tensors, + [&](phi::DenseTensor& output, + ncclComm_t comm, + const gpuStream_t& stream, + int src_rank) { + return platform::dynload::ncclRecv( + output.data(), + output.numel(), + platform::ToNCCLDataType(output.dtype()), + src_rank, + comm, + stream); + }, + src_rank, + CommType::RECV, + sync_op, + use_calc_stream); + return task; +} + std::shared_ptr ProcessGroupNCCL::AllGather( std::vector& in_tensors, std::vector& out_tensors) { diff --git a/paddle/fluid/distributed/collective/ProcessGroupNCCL.h b/paddle/fluid/distributed/collective/ProcessGroupNCCL.h index e0e298e9113e9..0b8fa54cd337e 100644 --- a/paddle/fluid/distributed/collective/ProcessGroupNCCL.h +++ b/paddle/fluid/distributed/collective/ProcessGroupNCCL.h @@ -60,7 +60,7 @@ class ProcessGroupNCCL : public ProcessGroupStream { int rank, CommType comm_type, const std::vector& inputs, - bool is_sync, + bool sync_op, bool use_calc_stream); bool IsCompleted(); @@ -122,19 +122,47 @@ class ProcessGroupNCCL : public ProcessGroupStream { std::shared_ptr Send( std::vector& tensors, int dst_rank) override; + std::shared_ptr Send( + std::vector& tensors, + int dst_rank, + bool sync_op, + bool use_calc_stream) override; + std::shared_ptr Recv( std::vector& tensors, int src_rank) override; + std::shared_ptr Recv( + std::vector& tensors, + int src_rank, + bool sync_op, + bool use_calc_stream) override; + std::shared_ptr Send_Partial(phi::DenseTensor& tensors, int dst_rank, int offset, int length) override; + std::shared_ptr Send_Partial( + phi::DenseTensor& tensors, + int dst_rank, + int offset, + int length, + bool sync_op, + bool use_calc_stream) override; + std::shared_ptr Recv_Partial(phi::DenseTensor& tensors, int src_rank, int offset, int length) override; + std::shared_ptr Recv_Partial( + phi::DenseTensor& tensors, + int src_rank, + int offset, + int length, + bool sync_op, + bool use_calc_stream) override; + std::shared_ptr AllGather( std::vector& in_tensors, std::vector& out_tensors) override; @@ -180,9 +208,17 @@ class ProcessGroupNCCL : public ProcessGroupStream { virtual std::shared_ptr CreateTask( std::vector places, int rank, - CommType opType, + CommType op_type, const std::vector& inputs); + virtual std::shared_ptr CreateTask( + const std::vector& places, + int rank, + CommType op_type, + const std::vector& inputs, + bool sync_op, + bool use_calc_stream); + protected: std::shared_ptr store_; std::shared_ptr nccl_comm_; @@ -233,6 +269,15 @@ class ProcessGroupNCCL : public ProcessGroupStream { int dst_rank, CommType op_type); + template + std::shared_ptr PointToPoint( + std::vector& tensors, // NOLINT + Fn fn, + int dst_rank, + CommType op_type, + bool sync_op, + bool use_calc_stream); + void CreateNCCLManagerCache(const std::string& places_key, const std::vector& places); diff --git a/paddle/fluid/distributed/collective/ProcessGroupStream.cc b/paddle/fluid/distributed/collective/ProcessGroupStream.cc index 9a20b8e6eaf79..51c8fe7bd9b1b 100644 --- a/paddle/fluid/distributed/collective/ProcessGroupStream.cc +++ b/paddle/fluid/distributed/collective/ProcessGroupStream.cc @@ -45,5 +45,89 @@ std::shared_ptr ProcessGroupStream::AllReduce( "ProcessGroup%s does not support do allreduce", GetBackendName())); } +std::shared_ptr ProcessGroupStream::Send( + std::vector& tensors, int dst_rank, bool sync_op) { + return Send(tensors, + dst_rank, + sync_op, + 
/*use_calc_stream*/ false); +} + +std::shared_ptr ProcessGroupStream::Send( + std::vector& tensors, + int dst_rank, + bool sync_op, + bool use_calc_stream) { + PADDLE_THROW(platform::errors::InvalidArgument( + "ProcessGroup%s does not support do send", GetBackendName())); +} + +std::shared_ptr ProcessGroupStream::Send_Partial( + phi::DenseTensor& tensors, + int dst_rank, + int offset, + int length, + bool sync_op) { + return Send_Partial(tensors, + dst_rank, + offset, + length, + sync_op, + /*use_calc_stream*/ false); +} + +std::shared_ptr ProcessGroupStream::Send_Partial( + phi::DenseTensor& tensors, + int dst_rank, + int offset, + int length, + bool sync_op, + bool use_calc_stream) { + PADDLE_THROW(platform::errors::InvalidArgument( + "ProcessGroup%s does not support do send_partial", GetBackendName())); +} + +std::shared_ptr ProcessGroupStream::Recv( + std::vector& tensors, int src_rank, bool sync_op) { + return Recv(tensors, + src_rank, + sync_op, + /*use_calc_stream*/ false); +} + +std::shared_ptr ProcessGroupStream::Recv( + std::vector& tensors, + int src_rank, + bool sync_op, + bool use_calc_stream) { + PADDLE_THROW(platform::errors::InvalidArgument( + "ProcessGroup%s does not support do recv", GetBackendName())); +} + +std::shared_ptr ProcessGroupStream::Recv_Partial( + phi::DenseTensor& tensors, + int src_rank, + int offset, + int length, + bool sync_op) { + return Recv_Partial(tensors, + src_rank, + offset, + length, + sync_op, + /*use_calc_stream*/ false); +} + +std::shared_ptr ProcessGroupStream::Recv_Partial( + phi::DenseTensor& tensors, + int src_rank, + int offset, + int length, + bool sync_op, + bool use_calc_stream) { + PADDLE_THROW(platform::errors::InvalidArgument( + "ProcessGroup%s does not support do recv_partial", GetBackendName())); +} + } // namespace distributed } // namespace paddle diff --git a/paddle/fluid/distributed/collective/ProcessGroupStream.h b/paddle/fluid/distributed/collective/ProcessGroupStream.h index 81a05ee2416e0..4cd17ac72562e 100644 --- a/paddle/fluid/distributed/collective/ProcessGroupStream.h +++ b/paddle/fluid/distributed/collective/ProcessGroupStream.h @@ -66,6 +66,58 @@ class ProcessGroupStream : public ProcessGroup { const AllreduceOptions& options, bool sync_op, bool use_calc_stream); + + std::shared_ptr Send( + std::vector& tensors, // NOLINT + int dst_rank, + bool sync_op) override; + + virtual std::shared_ptr Send( + std::vector& tensors, // NOLINT + int dst_rank, + bool sync_op, + bool use_calc_stream); + + std::shared_ptr Send_Partial( + phi::DenseTensor& tensors, // NOLINT + int dst_rank, + int offset, + int length, + bool sync_op) override; + + virtual std::shared_ptr Send_Partial( + phi::DenseTensor& tensors, // NOLINT + int dst_rank, + int offset, + int length, + bool sync_op, + bool use_calc_stream); + + std::shared_ptr Recv( + std::vector& tensors, // NOLINT + int src_rank, + bool sync_op) override; + + virtual std::shared_ptr Recv( + std::vector& tensors, // NOLINT + int src_rank, + bool sync_op, + bool use_calc_stream); + + std::shared_ptr Recv_Partial( + phi::DenseTensor& tensors, // NOLINT + int src_rank, + int offset, + int length, + bool sync_op) override; + + virtual std::shared_ptr Recv_Partial( + phi::DenseTensor& tensors, // NOLINT + int src_rank, + int offset, + int length, + bool sync_op, + bool use_calc_stream); }; } // namespace distributed diff --git a/paddle/fluid/pybind/distributed_py.cc b/paddle/fluid/pybind/distributed_py.cc index 5a7e2355f64eb..8a434f42811a8 100644 --- a/paddle/fluid/pybind/distributed_py.cc 
+++ b/paddle/fluid/pybind/distributed_py.cc @@ -196,6 +196,23 @@ void BindDistributed(py::module *m) { py::arg("dst"), py::call_guard()) + .def( + "send", + [](distributed::ProcessGroup &self, + py::handle py_tensor, + int dst, + bool sync_op) { + auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0); + auto dense = + std::dynamic_pointer_cast(tensor.impl()); + std::vector tensors = {*dense}; + return self.Send(tensors, dst, sync_op); + }, + py::arg("tensor"), + py::arg("dst"), + py::arg("sync_op"), + py::call_guard()) + .def( "send_partial", [](distributed::ProcessGroup &self, @@ -217,6 +234,30 @@ void BindDistributed(py::module *m) { py::arg("id"), py::call_guard()) + .def( + "send_partial", + [](distributed::ProcessGroup &self, + py::handle py_tensor, + int dst_rank, + int nranks, + int rank_id, + bool sync_op) { + auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0); + auto dense = + std::dynamic_pointer_cast(tensor.impl()); + int numel = (*dense).numel(); + int send_numel = numel / nranks; + int offset = send_numel * rank_id; + return self.Send_Partial( + *dense, dst_rank, offset, send_numel, sync_op); + }, + py::arg("tensor"), + py::arg("dst"), + py::arg("num"), + py::arg("id"), + py::arg("sync_op"), + py::call_guard()) + .def( "recv", [](distributed::ProcessGroup &self, @@ -232,6 +273,23 @@ void BindDistributed(py::module *m) { py::arg("src"), py::call_guard()) + .def( + "recv", + [](distributed::ProcessGroup &self, + py::handle py_tensor, + int src, + bool sync_op) { + auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0); + auto dense = + std::dynamic_pointer_cast(tensor.impl()); + std::vector tensors = {*dense}; + return self.Recv(tensors, src, sync_op); + }, + py::arg("tensor"), + py::arg("src"), + py::arg("sync_op"), + py::call_guard()) + .def( "recv_partial", [](distributed::ProcessGroup &self, @@ -253,6 +311,30 @@ void BindDistributed(py::module *m) { py::arg("id"), py::call_guard()) + .def( + "recv_partial", + [](distributed::ProcessGroup &self, + py::handle py_tensor, + int src_rank, + int nranks, + int rank_id, + bool sync_op) { + auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0); + auto dense = + std::dynamic_pointer_cast(tensor.impl()); + int numel = (*dense).numel(); + int recv_numel = numel / nranks; + int offset = recv_numel * rank_id; + return self.Recv_Partial( + *dense, src_rank, offset, recv_numel, sync_op); + }, + py::arg("tensor"), + py::arg("src"), + py::arg("num"), + py::arg("id"), + py::arg("sync_op"), + py::call_guard()) + .def( "all_gather", [](distributed::ProcessGroup &self, @@ -427,6 +509,94 @@ void BindDistributed(py::module *m) { }, py::arg("tensor"), py::arg("op"), + py::call_guard()) + + .def( + "send_on_calc_stream", + [](distributed::ProcessGroupStream &self, + py::handle py_tensor, + int dst) { + auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0); + auto dense = + std::dynamic_pointer_cast(tensor.impl()); + std::vector tensors = {*dense}; + return self.Send(tensors, + dst, + /*sync_op*/ true, + /*use_calc_stream*/ true); + }, + py::arg("tensor"), + py::arg("dst"), + py::call_guard()) + + .def( + "send_partial_on_calc_stream", + [](distributed::ProcessGroupStream &self, + py::handle py_tensor, + int dst_rank, + int nranks, + int rank_id) { + auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0); + auto dense = + std::dynamic_pointer_cast(tensor.impl()); + int numel = (*dense).numel(); + int send_numel = numel / nranks; + int offset = send_numel * rank_id; + return self.Send_Partial(*dense, + dst_rank, + offset, + send_numel, + /*sync_op*/ true, + 
/*use_calc_stream*/ true); + }, + py::arg("tensor"), + py::arg("dst"), + py::arg("num"), + py::arg("id"), + py::call_guard()) + + .def( + "recv_on_calc_stream", + [](distributed::ProcessGroupStream &self, + py::handle py_tensor, + int src) { + auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0); + auto dense = + std::dynamic_pointer_cast(tensor.impl()); + std::vector tensors = {*dense}; + return self.Recv(tensors, + src, + /*sync_op*/ true, + /*use_calc_stream*/ true); + }, + py::arg("tensor"), + py::arg("src"), + py::call_guard()) + + .def( + "recv_partial_on_calc_stream", + [](distributed::ProcessGroupStream &self, + py::handle py_tensor, + int src_rank, + int nranks, + int rank_id) { + auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0); + auto dense = + std::dynamic_pointer_cast(tensor.impl()); + int numel = (*dense).numel(); + int recv_numel = numel / nranks; + int offset = recv_numel * rank_id; + return self.Recv_Partial(*dense, + src_rank, + offset, + recv_numel, + /*sync_op*/ true, + /*use_calc_stream*/ true); + }, + py::arg("tensor"), + py::arg("src"), + py::arg("num"), + py::arg("id"), py::call_guard()); #if defined(PADDLE_WITH_RCCL) || defined(PADDLE_WITH_NCCL) diff --git a/python/paddle/distributed/communication/stream/__init__.py b/python/paddle/distributed/communication/stream/__init__.py index 24194dd9fb1e2..3dd9f60b81295 100644 --- a/python/paddle/distributed/communication/stream/__init__.py +++ b/python/paddle/distributed/communication/stream/__init__.py @@ -13,5 +13,7 @@ # limitations under the License. from .all_reduce import all_reduce +from .send import send +from .recv import recv -__all__ = ["all_reduce"] +__all__ = ["all_reduce", "send", "recv"] diff --git a/python/paddle/distributed/communication/stream/all_reduce.py b/python/paddle/distributed/communication/stream/all_reduce.py index 6a0b622cf0dfe..f94422f4bd0a6 100644 --- a/python/paddle/distributed/communication/stream/all_reduce.py +++ b/python/paddle/distributed/communication/stream/all_reduce.py @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import paddle.distributed.collective as collective import paddle.fluid.framework as framework -from ...collective import _get_default_group, _get_reduce_op, ReduceOp def _all_reduce_in_dygraph(tensor, op, group, sync_op, use_calc_stream): - op_type = _get_reduce_op(op, "all_reduce") - group = _get_default_group() if group is None else group + op_type = collective._get_reduce_op(op, "all_reduce") + group = collective._get_default_group() if group is None else group if use_calc_stream: return group.process_group.allreduce_on_calc_stream(tensor, op_type) @@ -30,7 +30,7 @@ def _all_reduce_in_dygraph(tensor, op, group, sync_op, use_calc_stream): def all_reduce(tensor, - op=ReduceOp.SUM, + op=collective.ReduceOp.SUM, group=None, sync_op=True, use_calc_stream=False): diff --git a/python/paddle/distributed/communication/stream/recv.py b/python/paddle/distributed/communication/stream/recv.py new file mode 100644 index 0000000000000..b225f64b8b4d2 --- /dev/null +++ b/python/paddle/distributed/communication/stream/recv.py @@ -0,0 +1,82 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle.distributed.collective as collective +import paddle.fluid.framework as framework + + +def _recv_in_dygraph(tensor, src, group, sync_op, use_calc_stream): + group = collective._get_default_group() if group is None else group + if use_calc_stream: + return group.process_group.recv_on_calc_stream(tensor, src) + + task = group.process_group.recv(tensor, src, sync_op) + if sync_op: + task.wait() + + return task + + +def recv(tensor, src=0, group=None, sync_op=True, use_calc_stream=False): + """ + + Receive a tensor from the source device. + + Args: + tensor (Tensor): The tensor to receive. Support float16, float32, float64, int32, int64, int8, uint8 or bool as its data type. + src (int, optional): Rank of the source device. If none is given, use `0` as default. + group (Group, optional): Communicate in which group. If none is given, use the global group as default. + sync_op (bool, optional): Indicate whether the communication is sync or not. If none is given, use true as default. + use_calc_stream (bool, optional): Indicate whether the communication is done on calculation stream. If none is given, use false as default. This + option is designed for high performance demand, be careful to turn it on except you are clearly know its meaning. + + Returns: + Return a task object. + + Warning: + This API only supports the dygraph mode now. + + Examples: + .. code-block:: python + + # required: distributed + import paddle + import paddle.distributed as dist + + dist.init_parallel_env() + local_rank = dist.get_rank() + if local_rank == 0: + data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]]) + task = dist.stream.send(data, dst=1, sync_op=False) + else: + data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]]) + task = dist.stream.recv(data, src=0, sync_op=False) + task.wait() + out = data.numpy() + # [[4, 5, 6], [4, 5, 6] + """ + if group is not None and not group.is_member(): + raise RuntimeError( + "The group should not be None and all ranks which invoke this operation should be the member of this group." + ) + + if not sync_op and use_calc_stream: + raise RuntimeError( + "use_calc_stream can only be True in sync op behavior.") + + if framework.in_dygraph_mode(): + return _recv_in_dygraph(tensor, src, group, sync_op, use_calc_stream) + + raise RuntimeError( + "paddle.distributed.stream.recv is only supported in dygraph mode now.") diff --git a/python/paddle/distributed/communication/stream/send.py b/python/paddle/distributed/communication/stream/send.py new file mode 100644 index 0000000000000..fa052734c7ee7 --- /dev/null +++ b/python/paddle/distributed/communication/stream/send.py @@ -0,0 +1,82 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle.distributed.collective as collective +import paddle.fluid.framework as framework + + +def _send_in_dygraph(tensor, dst, group, sync_op, use_calc_stream): + group = collective._get_default_group() if group is None else group + if use_calc_stream: + return group.process_group.send_on_calc_stream(tensor, dst) + + task = group.process_group.send(tensor, dst, sync_op) + if sync_op: + task.wait() + + return task + + +def send(tensor, dst=0, group=None, sync_op=True, use_calc_stream=False): + """ + + Send a tensor to the destination device. + + Args: + tensor (Tensor): The tensor to send. Support float16, float32, float64, int32, int64, int8, uint8 or bool as its data type. + dst (int, optional): Rank of the destination device. If none is given, use `0` as default. + group (Group, optional): Communicate in which group. If none is given, use the global group as default. + sync_op (bool, optional): Indicate whether the communication is sync or not. If none is given, use true as default. + use_calc_stream (bool, optional): Indicate whether the communication is done on calculation stream. If none is given, use false as default. This + option is designed for high performance demand, be careful to turn it on except you are clearly know its meaning. + + Returns: + Return a task object. + + Warning: + This API only supports the dygraph mode now. + + Examples: + .. code-block:: python + + # required: distributed + import paddle + import paddle.distributed as dist + + dist.init_parallel_env() + local_rank = dist.get_rank() + if local_rank == 0: + data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]]) + task = dist.stream.send(data, dst=1, sync_op=False) + else: + data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]]) + task = dist.stream.recv(data, src=0, sync_op=False) + task.wait() + out = data.numpy() + # [[4, 5, 6], [4, 5, 6] + """ + if group is not None and not group.is_member(): + raise RuntimeError( + "The group should not be None and all ranks which invoke this operation should be the member of this group." 
+ ) + + if not sync_op and use_calc_stream: + raise RuntimeError( + "use_calc_stream can only be True in sync op behavior.") + + if framework.in_dygraph_mode(): + return _send_in_dygraph(tensor, dst, group, sync_op, use_calc_stream) + + raise RuntimeError( + "paddle.distributed.stream.send is only supported in dygraph mode now.") diff --git a/python/paddle/fluid/tests/unittests/collective/CMakeLists.txt b/python/paddle/fluid/tests/unittests/collective/CMakeLists.txt index 5a1a6df2dd7ec..55f4453b1ab08 100644 --- a/python/paddle/fluid/tests/unittests/collective/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/collective/CMakeLists.txt @@ -268,17 +268,26 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX)) endif() if((WITH_GPU OR WITH_ROCM) AND (LINUX)) py_test_modules( - test_eager_dist_api MODULES test_eager_dist_api ENVS - "http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python") - set_tests_properties(test_eager_dist_api PROPERTIES TIMEOUT "120" LABELS - "RUN_TYPE=DIST") + test_communication_stream_allreduce_api MODULES + test_communication_stream_allreduce_api ENVS + "PYTHONPATH=..:${PADDLE_BINARY_DIR}/python;http_proxy=;https_proxy=") + set_tests_properties(test_communication_stream_allreduce_api + PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST") endif() if((WITH_GPU OR WITH_ROCM) AND (LINUX)) py_test_modules( - test_new_group_api MODULES test_new_group_api ENVS + test_communication_stream_sendrecv_api MODULES + test_communication_stream_sendrecv_api ENVS + "PYTHONPATH=..:${PADDLE_BINARY_DIR}/python;http_proxy=;https_proxy=") + set_tests_properties(test_communication_stream_sendrecv_api + PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST") +endif() +if((WITH_GPU OR WITH_ROCM) AND (LINUX)) + py_test_modules( + test_eager_dist_api MODULES test_eager_dist_api ENVS "http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python") - set_tests_properties(test_new_group_api PROPERTIES TIMEOUT "120" LABELS - "RUN_TYPE=DIST") + set_tests_properties(test_eager_dist_api PROPERTIES TIMEOUT "120" LABELS + "RUN_TYPE=DIST") endif() if((WITH_GPU OR WITH_ROCM @@ -298,11 +307,10 @@ if((WITH_GPU endif() if((WITH_GPU OR WITH_ROCM) AND (LINUX)) py_test_modules( - test_communication_stream_allreduce_api MODULES - test_communication_stream_allreduce_api ENVS - "PYTHONPATH=..:${PADDLE_BINARY_DIR}/python;http_proxy=;https_proxy=") - set_tests_properties(test_communication_stream_allreduce_api - PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST") + test_new_group_api MODULES test_new_group_api ENVS + "http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python") + set_tests_properties(test_new_group_api PROPERTIES TIMEOUT "120" LABELS + "RUN_TYPE=DIST") endif() if((WITH_ROCM OR WITH_GPU) AND (LINUX)) bash_test_modules( diff --git a/python/paddle/fluid/tests/unittests/collective/communication_stream_sendrecv_api_dygraph.py b/python/paddle/fluid/tests/unittests/collective/communication_stream_sendrecv_api_dygraph.py new file mode 100644 index 0000000000000..175e24c3d0d86 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/collective/communication_stream_sendrecv_api_dygraph.py @@ -0,0 +1,68 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import numpy as np +import paddle +import paddle.distributed as dist +import paddle.fluid as fluid +import test_collective_api_base as test_collective_base +import test_communication_api_base as test_base + + +class StreamSendRecvTestCase(): + + def __init__(self): + self._sync_op = eval(os.getenv("sync_op")) + self._use_calc_stream = eval(os.getenv("use_calc_stream")) + self._backend = os.getenv("backend") + self._shape = eval(os.getenv("shape")) + self._dtype = os.getenv("dtype") + self._seeds = eval(os.getenv("seeds")) + if self._backend not in ["nccl", "gloo"]: + raise NotImplementedError( + "Only support nccl and gloo as the backend for now.") + os.environ["PADDLE_DISTRI_BACKEND"] = self._backend + + def run_test_case(self): + dist.init_parallel_env() + + test_data_list = [] + for seed in self._seeds: + test_data_list.append( + test_collective_base.create_test_data(shape=self._shape, + dtype=self._dtype, + seed=seed)) + + rank = dist.get_rank() + tensor = paddle.to_tensor(test_data_list[rank]) + if rank == 0: + task = dist.stream.send(tensor, + dst=1, + sync_op=self._sync_op, + use_calc_stream=self._use_calc_stream) + else: + task = dist.stream.recv(tensor, + src=0, + sync_op=self._sync_op, + use_calc_stream=self._use_calc_stream) + if not self._sync_op: + task.wait() + + result = test_data_list[0] + assert np.allclose(tensor, result, rtol=1e-05, atol=1e-05) + + +if __name__ == "__main__": + StreamSendRecvTestCase().run_test_case() diff --git a/python/paddle/fluid/tests/unittests/collective/test_communication_stream_sendrecv_api.py b/python/paddle/fluid/tests/unittests/collective/test_communication_stream_sendrecv_api.py new file mode 100644 index 0000000000000..9590519bc2e13 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/collective/test_communication_stream_sendrecv_api.py @@ -0,0 +1,50 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest +import paddle +import test_communication_api_base as test_base + + +class TestCommunicationStreamSendRecvAPI(test_base.CommunicationTestDistBase): + + def setUp(self): + super(TestCommunicationStreamSendRecvAPI, self).setUp(num_of_devices=2, + timeout=120) + self._default_envs = { + "backend": "nccl", + "shape": "(100, 200)", + "dtype": "float32", + "seeds": str(self._seeds) + } + self._changeable_envs = { + "sync_op": ["True", "False"], + "use_calc_stream": ["True", "False"] + } + + def test_sendrecv_stream(self): + envs_list = test_base.gen_product_envs_list(self._default_envs, + self._changeable_envs) + for envs in envs_list: + if eval(envs["use_calc_stream"]) and not eval(envs["sync_op"]): + continue + self.run_test_case("communication_stream_sendrecv_api_dygraph.py", + user_defined_envs=envs) + + def tearDown(self): + super(TestCommunicationStreamSendRecvAPI, self).tearDown() + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/collective/testslist.csv b/python/paddle/fluid/tests/unittests/collective/testslist.csv index 16eb200565f73..b4ba281f45420 100644 --- a/python/paddle/fluid/tests/unittests/collective/testslist.csv +++ b/python/paddle/fluid/tests/unittests/collective/testslist.csv @@ -32,8 +32,9 @@ test_collective_split_col_linear,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_ test_collective_split_embedding_none_divisible,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., test_collective_split_row_linear,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., test_collective_wait,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., +test_communication_stream_allreduce_api,linux,gpu;rocm,120,DIST,,2,,PYTHONPATH=..;http_proxy=;https_proxy=, +test_communication_stream_sendrecv_api,linux,gpu;rocm,120,DIST,,2,,PYTHONPATH=..;http_proxy=;https_proxy=, test_eager_dist_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., -test_new_group_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., test_gen_nccl_id_op,,gpu;rocm;ASCEND;ASCEND_CL,,DIST,../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=.., -test_communication_stream_allreduce_api,linux,gpu;rocm,120,DIST,,2,,PYTHONPATH=..;http_proxy=;https_proxy=, +test_new_group_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., test_world_size_and_rank,linux,rocm;gpu,120,DIST,test_world_size_and_rank.sh,2,,http_proxy=;https_proxy=,
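For reference, below is a minimal usage sketch of the stream-level send/recv APIs this patch introduces, distilled from the docstrings and checks added in send.py and recv.py above. It is an illustration, not part of the commit: the two-rank launch command in the comment and the tensor values are assumptions, and only the behavior shown in the patch (async tasks requiring wait(), and use_calc_stream being valid only with sync_op=True) is relied on.

# Minimal sketch of the new paddle.distributed.stream.send/recv usage.
# Assumes a 2-rank NCCL job, e.g. launched with
#   python -m paddle.distributed.launch --gpus 0,1 demo.py   (launch command assumed)
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
rank = dist.get_rank()

if rank == 0:
    data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
else:
    data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])

# 1) Asynchronous point-to-point on the dedicated communication stream:
#    sync_op=False returns a task that must be waited on before reading `data`.
if rank == 0:
    task = dist.stream.send(data, dst=1, sync_op=False)
else:
    task = dist.stream.recv(data, src=0, sync_op=False)
task.wait()

# 2) Synchronous point-to-point on the calculation stream:
#    use_calc_stream=True is only allowed together with sync_op=True;
#    the API raises a RuntimeError for the async + calc-stream combination.
if rank == 0:
    dist.stream.send(data, dst=1, sync_op=True, use_calc_stream=True)
else:
    dist.stream.recv(data, src=0, sync_op=True, use_calc_stream=True)

print(data.numpy())  # both ranks now hold [[4, 5, 6], [4, 5, 6]]

This mirrors the constraint enforced in send.py/recv.py ("use_calc_stream can only be True in sync op behavior") and the corresponding test, which skips the use_calc_stream=True, sync_op=False combination.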