From ed846f119b4ba3687c7a28240d98ae68a80f58ee Mon Sep 17 00:00:00 2001
From: Changho Hwang
Date: Mon, 22 Apr 2024 21:55:46 +0000
Subject: [PATCH 1/2] Refactoring NVLS interfaces

---
 include/mscclpp/core.hpp            | 61 ------------------------
 include/mscclpp/nvls.hpp            | 72 +++++++++++++++++++++++++++++
 include/mscclpp/nvls_device.hpp     | 14 +++---
 python/mscclpp/comm.py              |  2 +-
 python/mscclpp/core_py.cpp          | 22 +--------
 python/mscclpp/nvls_py.cpp          | 38 +++++++++++++++
 src/communicator.cc                 | 39 ----------------
 src/{nvls_connection.cc => nvls.cc} | 41 ++++++++++++++++
 8 files changed, 161 insertions(+), 128 deletions(-)
 create mode 100644 include/mscclpp/nvls.hpp
 create mode 100644 python/mscclpp/nvls_py.cpp
 rename src/{nvls_connection.cc => nvls.cc} (88%)

diff --git a/include/mscclpp/core.hpp b/include/mscclpp/core.hpp
index 663d19655..ffbde7bf8 100644
--- a/include/mscclpp/core.hpp
+++ b/include/mscclpp/core.hpp
@@ -15,7 +15,6 @@
 #include
 #include
 #include
-#include <mscclpp/nvls_device.hpp>
 #include
 #include
@@ -455,55 +454,12 @@ class Connection {
   static std::shared_ptr<Endpoint::Impl> getImpl(Endpoint& memory);
 };
 
-class NvlsConnection {
- public:
-  NvlsConnection(size_t bufferSize, int numDevices);
-  NvlsConnection(const std::vector<char>& data);
-  NvlsConnection() = delete;
-  std::vector<char> serialize();
-
-  // Everyone needs to synchronize after creating a NVLS connection before adding devices
-  void addDevice();
-  void addDevice(int cudaDeviceId);
-
-  struct DeviceMulticastPointer {
-   private:
-    std::shared_ptr<PhysicalCudaMemory<char>> deviceMem_;
-    std::shared_ptr<char> mcPtr_;
-    size_t bufferSize_;
-
-   public:
-    using DeviceHandle = DeviceMulticastPointerDeviceHandle;
-    DeviceMulticastPointer(std::shared_ptr<PhysicalCudaMemory<char>> deviceMem, std::shared_ptr<char> mcPtr,
-                           size_t bufferSize)
-        : deviceMem_(deviceMem), mcPtr_(mcPtr), bufferSize_(bufferSize) {}
-    DeviceHandle deviceHandle();
-    char* getDevicePtr();
-
-    friend class NvlsConnection;
-  };
-
-  std::shared_ptr<DeviceMulticastPointer> allocateAndBindCuda(size_t size);
-
-  /// The \p handle to the allocation (its lifetime is managed by the caller)
-  /// and the \p size of the allocation.
-  std::shared_ptr<char> bindAllocatedCuda(CUmemGenericAllocationHandle memHandle, size_t size);
-
-  size_t getMultiCastMinGranularity();
-
- private:
-  class Impl;
-  std::shared_ptr<Impl> pimpl_;
-};
-
 /// Used to configure an endpoint.
 struct EndpointConfig {
   static const int DefaultMaxCqSize = 1024;
   static const int DefaultMaxCqPollNum = 1;
   static const int DefaultMaxSendWr = 8192;
   static const int DefaultMaxWrPerSend = 64;
-  // the recommended buffer size for NVLS, returned by cuMulticastGetGranularity
-  static const int DefaultNvlsBufferSize = (1 << 29);
 
   Transport transport;
   int ibMaxCqSize = DefaultMaxCqSize;
@@ -511,8 +467,6 @@ struct EndpointConfig {
   int ibMaxSendWr = DefaultMaxSendWr;
   int ibMaxWrPerSend = DefaultMaxWrPerSend;
 
-  size_t nvlsBufferSize = DefaultNvlsBufferSize;
-
   /// Default constructor. Sets transport to Transport::Unknown.
   EndpointConfig() : transport(Transport::Unknown) {}
 
   /// Constructor for a given transport.
   ///
   /// @param transport The transport to use.
   EndpointConfig(Transport transport) : transport(transport) {}
-
-  /// Constructor for NVLS explicitly
-  /// @param transport must be either NvlsRoot or NvlsNonRoot
-  /// @param nvlsBufferSize is the buffer to be alloced on each device
-  EndpointConfig(Transport transport, size_t nvlsBufferSize) : transport(transport), nvlsBufferSize(nvlsBufferSize) {}
 };
 
 /// Represents a context for communication. This provides a low-level interface for forming connections in use-cases
@@ -708,16 +657,6 @@ class Communicator {
   /// to the connection.
   NonblockingFuture<std::shared_ptr<Connection>> connectOnSetup(int remoteRank, int tag, EndpointConfig localConfig);
 
-  /// Connect to NVLS on setup.
-  ///
-  /// This function used to connect to NVLS on setup. NVLS collective using multicast operations to send/recv data.
-  /// Here we need to put all involved ranks into the collective group.
-  ///
-  /// @param allRanks The ranks of all processes involved in the collective.
-  /// @param config The configuration for the local endpoint.
-  /// @return std::shared_ptr<NvlsConnection> A shared pointer to the NVLS connection.
-  std::shared_ptr<NvlsConnection> connectNvlsCollective(std::vector<int> allRanks, EndpointConfig config);
-
   /// Get the remote rank a connection is connected to.
   ///
   /// @param connection The connection to get the remote rank for.
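Taken together, the core.hpp hunks above move NVLS buffer sizing out of EndpointConfig and retire the Communicator method in favor of a free function declared in the new nvls.hpp below. A minimal caller-side sketch of the migration; the helper name setupNvls and the rank list are illustrative, not part of the patch:

    #include <mscclpp/core.hpp>
    #include <mscclpp/nvls.hpp>

    std::shared_ptr<mscclpp::NvlsConnection> setupNvls(std::shared_ptr<mscclpp::Communicator> comm) {
      // Before this patch (removed above), the buffer size traveled in EndpointConfig:
      //   EndpointConfig config(Transport::NvlsRoot, 1 << 29);
      //   auto conn = comm->connectNvlsCollective({0, 1, 2, 3}, config);
      // After: the size is an explicit argument with a library-provided default.
      return mscclpp::connectNvlsCollective(comm, {0, 1, 2, 3},
                                            mscclpp::NvlsConnection::DefaultNvlsBufferSize);
    }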
diff --git a/include/mscclpp/nvls.hpp b/include/mscclpp/nvls.hpp
new file mode 100644
index 000000000..b63be9d96
--- /dev/null
+++ b/include/mscclpp/nvls.hpp
@@ -0,0 +1,72 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+
+#ifndef MSCCLPP_NVLS_HPP_
+#define MSCCLPP_NVLS_HPP_
+
+#include <mscclpp/gpu_utils.hpp>
+#include <mscclpp/nvls_device.hpp>
+
+namespace mscclpp {
+
+class NvlsConnection {
+ public:
+  NvlsConnection(size_t bufferSize, int numDevices);
+  NvlsConnection(const std::vector<char>& data);
+  NvlsConnection() = delete;
+  std::vector<char> serialize();
+
+  // the recommended buffer size for NVLS, returned by cuMulticastGetGranularity
+  static const int DefaultNvlsBufferSize = (1 << 29);
+
+  // Everyone needs to synchronize after creating an NVLS connection before adding devices
+  void addDevice();
+  void addDevice(int cudaDeviceId);
+
+  struct DeviceMulticastPointer {
+   private:
+    std::shared_ptr<PhysicalCudaMemory<char>> deviceMem_;
+    std::shared_ptr<char> mcPtr_;
+    size_t bufferSize_;
+
+   public:
+    using DeviceHandle = DeviceMulticastPointerDeviceHandle;
+    DeviceMulticastPointer(std::shared_ptr<PhysicalCudaMemory<char>> deviceMem, std::shared_ptr<char> mcPtr,
+                           size_t bufferSize)
+        : deviceMem_(deviceMem), mcPtr_(mcPtr), bufferSize_(bufferSize) {}
+    DeviceHandle deviceHandle();
+    char* getDevicePtr();
+
+    friend class NvlsConnection;
+  };
+
+  std::shared_ptr<DeviceMulticastPointer> allocateAndBindCuda(size_t size);
+
+  /// The \p memHandle to the allocation (its lifetime is managed by the caller)
+  /// and the \p size of the allocation.
+  std::shared_ptr<char> bindAllocatedCuda(CUmemGenericAllocationHandle memHandle, size_t size);
+
+  size_t getMultiCastMinGranularity();
+
+ private:
+  class Impl;
+  std::shared_ptr<Impl> pimpl_;
+};
+
+class Communicator;
+
+/// Connect to NVLS on setup.
+///
+/// This function is used to connect to NVLS on setup. An NVLS collective uses multicast operations to send/recv
+/// data, so all ranks involved in the collective must be put into the group.
+///
+/// @param comm The communicator.
+/// @param allRanks The ranks of all processes involved in the collective.
+/// @param bufferSize The size of the buffer to allocate and bind on each device.
+/// @return std::shared_ptr<NvlsConnection> A shared pointer to the NVLS connection.
+std::shared_ptr<NvlsConnection> connectNvlsCollective(std::shared_ptr<Communicator> comm, std::vector<int> allRanks,
+                                                      size_t bufferSize = NvlsConnection::DefaultNvlsBufferSize);
+
+}  // namespace mscclpp
+
+#endif  // MSCCLPP_NVLS_HPP_
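The new header also carries the allocation API over from core.hpp unchanged. A sketch of the allocation flow, assuming conn was returned by connectNvlsCollective on every participating rank; the helper name and the rounding logic are illustrative:

    #include <memory>
    #include <mscclpp/nvls.hpp>

    void allocateNvlsBuffer(std::shared_ptr<mscclpp::NvlsConnection> conn, size_t requested) {
      // Round the request up to the minimum multicast granularity reported by CUDA.
      size_t gran = conn->getMultiCastMinGranularity();
      size_t size = (requested + gran - 1) / gran * gran;

      // allocateAndBindCuda allocates device memory and binds it to the multicast
      // object in one step; bindAllocatedCuda is the variant for caller-owned handles.
      auto ptr = conn->allocateAndBindCuda(size);

      auto handle = ptr->deviceHandle();      // POD handle, safe to pass to a kernel
      char* localView = ptr->getDevicePtr();  // unicast view of the same memory
      (void)handle;
      (void)localView;
    }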
diff --git a/include/mscclpp/nvls_device.hpp b/include/mscclpp/nvls_device.hpp
index b04defbcf..0307f6772 100644
--- a/include/mscclpp/nvls_device.hpp
+++ b/include/mscclpp/nvls_device.hpp
@@ -22,15 +22,15 @@ struct DeviceMulticastPointerDeviceHandle {
   size_t bufferSize;
 
 #if defined(MSCCLPP_DEVICE_CUDA)
-  template <int NElemPerThread = 4, typename TVaule, typename T>
-  MSCCLPP_DEVICE_INLINE static void multimemLoad(TVaule& val, T* ptr) {
+  template <int NElemPerThread = 4, typename TValue, typename T>
+  MSCCLPP_DEVICE_INLINE static void multimemLoad(TValue& val, T* ptr) {
     static_assert(NElemPerThread == 4, "Only support NElemPerThread == 4");
     if constexpr (std::is_same<T, float>::value) {
       asm("multimem.ld_reduce.relaxed.sys.global.add.v4.f32 {%0,%1,%2,%3}, [%4];"
           : "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w)
           : "l"(ptr)
           : "memory");
-    } else if constexpr (std::is_same::value) {
+    } else if constexpr (std::is_same::value) {
       asm("multimem.ld_reduce.relaxed.sys.global.add.v4.f16x2 {%0,%1,%2,%3}, [%4];"
           : "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w)
           : "l"(ptr)
@@ -40,14 +40,14 @@ struct DeviceMulticastPointerDeviceHandle {
     }
   };
 
-  template <int NElemPerThread = 4, typename TVaule, typename T>
-  MSCCLPP_DEVICE_INLINE static void multimemStore(const TVaule& val, T* ptr) {
+  template <int NElemPerThread = 4, typename TValue, typename T>
+  MSCCLPP_DEVICE_INLINE static void multimemStore(const TValue& val, T* ptr) {
     static_assert(NElemPerThread == 4, "Only support NElemPerThread == 4");
     if constexpr (std::is_same<T, float>::value) {
       asm volatile("multimem.st.relaxed.sys.global.v4.f32 [%0], {%1,%2,%3,%4};" ::"l"(ptr), "r"(val.x), "r"(val.y),
                    "r"(val.z), "r"(val.w)
                    : "memory");
-    } else if constexpr (std::is_same::value) {
+    } else if constexpr (std::is_same::value) {
       asm volatile("multimem.st.relaxed.sys.global.v4.f16x2 [%0], {%1,%2,%3,%4};" ::"l"(ptr), "r"(val.x), "r"(val.y),
                    "r"(val.z), "r"(val.w)
                    : "memory");
@@ -60,4 +60,4 @@
 
 }  // namespace mscclpp
 
-#endif  // MSCCLPP_SEMAPHORE_DEVICE_HPP_
+#endif  // MSCCLPP_NVLS_DEVICE_HPP_
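Beyond the TVaule-to-TValue rename and the corrected include-guard comment, the device-side interface is unchanged, so existing kernels keep working. A sketch of how these helpers are typically driven, assuming a float buffer whose element count is a multiple of 4, an SM90-class GPU, and compilation as CUDA device code; the kernel name and the choice of uint4 lanes (four f32 values, matching the 32-bit "r" register constraints in the asm above) are illustrative:

    #include <mscclpp/nvls_device.hpp>

    __global__ void multimemSumSketch(mscclpp::DeviceMulticastPointerDeviceHandle nvls, size_t nElem) {
      float* mcBuf = reinterpret_cast<float*>(nvls.mcPtr);
      size_t stride = (size_t)blockDim.x * gridDim.x;
      for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i * 4 < nElem; i += stride) {
        uint4 val;
        // ld_reduce: the switch returns the sum of this offset across all bound GPUs
        mscclpp::DeviceMulticastPointerDeviceHandle::multimemLoad(val, mcBuf + i * 4);
        // st: the reduced vector is written through the multicast address,
        // updating every GPU's copy of the buffer
        mscclpp::DeviceMulticastPointerDeviceHandle::multimemStore(val, mcBuf + i * 4);
      }
    }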
diff --git a/python/mscclpp/comm.py b/python/mscclpp/comm.py
index 1cf9ebb41..4f0111d48 100644
--- a/python/mscclpp/comm.py
+++ b/python/mscclpp/comm.py
@@ -98,7 +98,7 @@ def make_connection(
         else:
             endpoint = endpoints
         if endpoint.transport == Transport.Nvls:
-            return self.communicator.connct_nvls_collective(all_ranks, endpoint)
+            return connect_nvls_collective(self.communicator, all_ranks)
         else:
             connections[rank] = self.communicator.connect_on_setup(rank, 0, endpoint)
     self.communicator.setup()
diff --git a/python/mscclpp/core_py.cpp b/python/mscclpp/core_py.cpp
index 3b7f48023..8dc9df57b 100644
--- a/python/mscclpp/core_py.cpp
+++ b/python/mscclpp/core_py.cpp
@@ -20,6 +20,7 @@ extern void register_fifo(nb::module_& m);
 extern void register_semaphore(nb::module_& m);
 extern void register_utils(nb::module_& m);
 extern void register_numa(nb::module_& m);
+extern void register_nvls(nb::module_& m);
 extern void register_executor(nb::module_& m);
 
 template <typename T>
@@ -128,24 +129,6 @@ void register_core(nb::module_& m) {
       .def("transport", &Connection::transport)
       .def("remote_transport", &Connection::remoteTransport);
 
-  nb::class_<NvlsConnection::DeviceMulticastPointer>(m, "DeviceMulticastPointer")
-      .def("get_device_ptr",
-           [](NvlsConnection::DeviceMulticastPointer* self) { return (uintptr_t)self->getDevicePtr(); })
-      .def("device_handle", &NvlsConnection::DeviceMulticastPointer::deviceHandle);
-
-  nb::class_<NvlsConnection::DeviceMulticastPointer::DeviceHandle>(m, "DeviceHandle")
-      .def(nb::init<>())
-      .def_rw("devicePtr", &NvlsConnection::DeviceMulticastPointer::DeviceHandle::devicePtr)
-      .def_rw("mcPtr", &NvlsConnection::DeviceMulticastPointer::DeviceHandle::mcPtr)
-      .def_rw("size", &NvlsConnection::DeviceMulticastPointer::DeviceHandle::bufferSize)
-      .def_prop_ro("raw", [](const NvlsConnection::DeviceMulticastPointer::DeviceHandle& self) -> nb::bytes {
-        return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
-      });
-
-  nb::class_<NvlsConnection>(m, "NvlsConnection")
-      .def("allocate_bind_memory", &NvlsConnection::allocateAndBindCuda)
-      .def("get_multicast_min_granularity", &NvlsConnection::getMultiCastMinGranularity);
-
   nb::class_<Endpoint>(m, "Endpoint")
       .def("transport", &Endpoint::transport)
       .def("serialize", &Endpoint::serialize)
@@ -154,7 +137,6 @@ void register_core(nb::module_& m) {
   nb::class_<EndpointConfig>(m, "EndpointConfig")
       .def(nb::init<>())
       .def(nb::init_implicit<Transport>(), nb::arg("transport"))
-      .def(nb::init<Transport, size_t>(), nb::arg("transport"), nb::arg("nvlsBufferSize"))
       .def_rw("transport", &EndpointConfig::transport)
       .def_rw("ib_max_cq_size", &EndpointConfig::ibMaxCqSize)
       .def_rw("ib_max_cq_poll_num", &EndpointConfig::ibMaxCqPollNum)
@@ -191,7 +173,6 @@ void register_core(nb::module_& m) {
       .def("recv_memory_on_setup", &Communicator::recvMemoryOnSetup, nb::arg("remoteRank"), nb::arg("tag"))
       .def("connect_on_setup", &Communicator::connectOnSetup, nb::arg("remoteRank"), nb::arg("tag"),
           nb::arg("localConfig"))
-      .def("connct_nvls_collective", &Communicator::connectNvlsCollective, nb::arg("allRanks"), nb::arg("config"))
       .def("remote_rank_of", &Communicator::remoteRankOf)
       .def("tag_of", &Communicator::tagOf)
       .def("setup", &Communicator::setup);
@@ -206,5 +187,6 @@ NB_MODULE(_mscclpp, m) {
   register_utils(m);
   register_core(m);
   register_numa(m);
+  register_nvls(m);
   register_executor(m);
 }
diff --git a/python/mscclpp/nvls_py.cpp b/python/mscclpp/nvls_py.cpp
new file mode 100644
index 000000000..5b6232e09
--- /dev/null
+++ b/python/mscclpp/nvls_py.cpp
@@ -0,0 +1,38 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include <mscclpp/core.hpp>
+#include <mscclpp/nvls.hpp>
+
+namespace nb = nanobind;
+using namespace mscclpp;
+
+void register_nvls(nb::module_& m) {
+  nb::class_<NvlsConnection::DeviceMulticastPointer>(m, "DeviceMulticastPointer")
+      .def("get_device_ptr",
+           [](NvlsConnection::DeviceMulticastPointer* self) { return (uintptr_t)self->getDevicePtr(); })
+      .def("device_handle", &NvlsConnection::DeviceMulticastPointer::deviceHandle);
+
+  nb::class_<NvlsConnection::DeviceMulticastPointer::DeviceHandle>(m, "DeviceHandle")
+      .def(nb::init<>())
+      .def_rw("devicePtr", &NvlsConnection::DeviceMulticastPointer::DeviceHandle::devicePtr)
+      .def_rw("mcPtr", &NvlsConnection::DeviceMulticastPointer::DeviceHandle::mcPtr)
+      .def_rw("size", &NvlsConnection::DeviceMulticastPointer::DeviceHandle::bufferSize)
+      .def_prop_ro("raw", [](const NvlsConnection::DeviceMulticastPointer::DeviceHandle& self) -> nb::bytes {
+        return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
+      });
+
+  nb::class_<NvlsConnection>(m, "NvlsConnection")
+      .def("allocate_bind_memory", &NvlsConnection::allocateAndBindCuda)
+      .def("get_multicast_min_granularity", &NvlsConnection::getMultiCastMinGranularity);
+
+  m.def("connect_nvls_collective", &connectNvlsCollective, nb::arg("communicator"), nb::arg("allRanks"),
+        nb::arg("config") = NvlsConnection::DefaultNvlsBufferSize);
+}
diff --git a/src/communicator.cc b/src/communicator.cc
index bd6a246d3..d2f0e6172 100644
--- a/src/communicator.cc
+++ b/src/communicator.cc
@@ -105,45 +105,6 @@ MSCCLPP_API_CPP NonblockingFuture<std::shared_ptr<Connection>> Communicator::con
   return NonblockingFuture<std::shared_ptr<Connection>>(connector->connectionPromise_.get_future());
 }
 
-MSCCLPP_API_CPP std::shared_ptr<NvlsConnection> Communicator::connectNvlsCollective(std::vector<int> allRanks,
-                                                                                    EndpointConfig config) {
-  auto bootstrap = this->bootstrap();
-  int rank = bootstrap->getRank();
-  bool isRoot = false;
-  bool amongAllRanks = false;
-  int rootRank = allRanks[0];
-  for (auto nvlsRank : allRanks) {
-    if (nvlsRank == rank) amongAllRanks = true;
-    rootRank = std::min(rootRank, nvlsRank);
-  }
-  if (amongAllRanks == false) {
-    throw Error("rank is not among allRanks", ErrorCode::InvalidUsage);
-  }
-  if (rootRank == rank) isRoot = true;
-
-  std::shared_ptr<NvlsConnection> conn;
-  if (isRoot) {
-    conn = std::make_shared<NvlsConnection>(config.nvlsBufferSize, allRanks.size());
-    auto serialized = conn->serialize();
-    for (auto nvlsRank : allRanks) {
-      if (nvlsRank != rank) bootstrap->send(serialized, nvlsRank, 0);
-    }
-  } else {
-    std::vector<char> data;
-    bootstrap->recv(data, rootRank, 0);
-    conn = std::make_shared<NvlsConnection>(data);
-  }
-
-  // Now let's synchronize all ranks
-  bootstrap->groupBarrier(allRanks);
-  // now it is safe to add my device
-  conn->addDevice();
-
-  // sync here to make sure all ranks have added their devices
-  bootstrap->groupBarrier(allRanks);
-  return conn;
-}
-
 MSCCLPP_API_CPP int Communicator::remoteRankOf(const Connection& connection) {
   return pimpl_->connectionInfos_.at(&connection).remoteRank;
 }
diff --git a/src/nvls_connection.cc b/src/nvls.cc
similarity index 88%
rename from src/nvls_connection.cc
rename to src/nvls.cc
index 36b418eff..c4a7c7ec8 100644
--- a/src/nvls_connection.cc
+++ b/src/nvls.cc
@@ -6,8 +6,10 @@
 #include
 #include
+#include <mscclpp/nvls.hpp>
 #include
 
+#include "api.h"
 #include "debug.h"
 #include "endpoint.hpp"
@@ -271,4 +273,43 @@ char* NvlsConnection::DeviceMulticastPointer::getDevicePtr() { return deviceMem_
 
 size_t NvlsConnection::getMultiCastMinGranularity() { return pimpl_->getMinMcGran(); }
 
+MSCCLPP_API_CPP std::shared_ptr<NvlsConnection> connectNvlsCollective(std::shared_ptr<Communicator> comm,
+                                                                      std::vector<int> allRanks, size_t bufferSize) {
+  auto bootstrap = comm->bootstrap();
+  int rank = bootstrap->getRank();
+  bool isRoot = false;
+  bool amongAllRanks = false;
+  int rootRank = allRanks[0];
+  for (auto nvlsRank : allRanks) {
+    if (nvlsRank == rank) amongAllRanks = true;
+    rootRank = std::min(rootRank, nvlsRank);
+  }
+  if (amongAllRanks == false) {
+    throw Error("rank is not among allRanks", ErrorCode::InvalidUsage);
+  }
+  if (rootRank == rank) isRoot = true;
+
+  std::shared_ptr<NvlsConnection> conn;
+  if (isRoot) {
+    conn = std::make_shared<NvlsConnection>(bufferSize, allRanks.size());
+    auto serialized = conn->serialize();
+    for (auto nvlsRank : allRanks) {
+      if (nvlsRank != rank) bootstrap->send(serialized, nvlsRank, 0);
+    }
+  } else {
+    std::vector<char> data;
+    bootstrap->recv(data, rootRank, 0);
+    conn = std::make_shared<NvlsConnection>(data);
+  }
+
+  // Now let's synchronize all ranks
+  bootstrap->groupBarrier(allRanks);
+  // now it is safe to add my device
+  conn->addDevice();
+
+  // sync here to make sure all ranks have added their devices
+  bootstrap->groupBarrier(allRanks);
+  return conn;
+}
+
 }  // namespace mscclpp

From eef96f23b93a5b31e42223d742703896c79b1e17 Mon Sep 17 00:00:00 2001
From: Changho Hwang
Date: Tue, 23 Apr 2024 18:15:39 +0000
Subject: [PATCH 2/2] update an argument name

---
 python/mscclpp/nvls_py.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/mscclpp/nvls_py.cpp b/python/mscclpp/nvls_py.cpp
index 5b6232e09..819a7c6b0 100644
--- a/python/mscclpp/nvls_py.cpp
+++ b/python/mscclpp/nvls_py.cpp
@@ -34,5 +34,5 @@ void register_nvls(nb::module_& m) {
       .def("get_multicast_min_granularity", &NvlsConnection::getMultiCastMinGranularity);
 
   m.def("connect_nvls_collective", &connectNvlsCollective, nb::arg("communicator"), nb::arg("allRanks"),
-        nb::arg("config") = NvlsConnection::DefaultNvlsBufferSize);
+        nb::arg("bufferSize") = NvlsConnection::DefaultNvlsBufferSize);
 }
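With PATCH 2/2 applied, the exposed keyword argument matches the C++ parameter, so Python callers write connect_nvls_collective(communicator, all_ranks, bufferSize=...). An end-to-end C++ sketch tying the pieces together, assuming eight single-GPU processes that all run the same code and a kernel like the device-side sketch earlier; names are illustrative, not part of the patch:

    #include <memory>
    #include <vector>
    #include <mscclpp/core.hpp>
    #include <mscclpp/nvls.hpp>

    void runNvls(std::shared_ptr<mscclpp::Communicator> comm) {
      std::vector<int> allRanks = {0, 1, 2, 3, 4, 5, 6, 7};

      // Internally (see src/nvls.cc above): the root rank creates and serializes the
      // multicast handle, the other ranks deserialize it, and two group barriers
      // bracket addDevice() so no rank binds memory before everyone has joined.
      auto conn = mscclpp::connectNvlsCollective(comm, allRanks);  // default 512 MiB buffer

      auto mem = conn->allocateAndBindCuda(conn->getMultiCastMinGranularity());
      auto handle = mem->deviceHandle();
      // Launch something like: multimemSumSketch<<<grid, block>>>(handle, nElem);
      (void)handle;
    }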