Refactoring NVLS interfaces #293

Merged 2 commits on Apr 24, 2024
Changes from all commits
61 changes: 0 additions & 61 deletions include/mscclpp/core.hpp
@@ -15,7 +15,6 @@
#include <memory>
#include <mscclpp/gpu.hpp>
#include <mscclpp/gpu_utils.hpp>
#include <mscclpp/nvls_device.hpp>
#include <string>
#include <vector>

@@ -455,76 +454,26 @@ class Connection {
static std::shared_ptr<Endpoint::Impl> getImpl(Endpoint& memory);
};

class NvlsConnection {
public:
NvlsConnection(size_t bufferSize, int numDevices);
NvlsConnection(const std::vector<char>& data);
NvlsConnection() = delete;
std::vector<char> serialize();

// Everyone needs to synchronize after creating an NVLS connection before adding devices
void addDevice();
void addDevice(int cudaDeviceId);

struct DeviceMulticastPointer {
private:
std::shared_ptr<PhysicalCudaMemory<char>> deviceMem_;
std::shared_ptr<char> mcPtr_;
size_t bufferSize_;

public:
using DeviceHandle = DeviceMulticastPointerDeviceHandle;
DeviceMulticastPointer(std::shared_ptr<PhysicalCudaMemory<char>> deviceMem, std::shared_ptr<char> mcPtr,
size_t bufferSize)
: deviceMem_(deviceMem), mcPtr_(mcPtr), bufferSize_(bufferSize) {}
DeviceHandle deviceHandle();
char* getDevicePtr();

friend class NvlsConnection;
};

std::shared_ptr<DeviceMulticastPointer> allocateAndBindCuda(size_t size);

/// Bind an existing allocation to the multicast handle.
/// @param memHandle The handle to the allocation (its lifetime is managed by the caller).
/// @param size The size of the allocation.
std::shared_ptr<char> bindAllocatedCuda(CUmemGenericAllocationHandle memHandle, size_t size);

size_t getMultiCastMinGranularity();

private:
class Impl;
std::shared_ptr<Impl> pimpl_;
};

/// Used to configure an endpoint.
struct EndpointConfig {
static const int DefaultMaxCqSize = 1024;
static const int DefaultMaxCqPollNum = 1;
static const int DefaultMaxSendWr = 8192;
static const int DefaultMaxWrPerSend = 64;
// the recommended buffer size for NVLS, returned by cuMulticastGetGranularity
static const int DefaultNvlsBufferSize = (1 << 29);

Transport transport;
int ibMaxCqSize = DefaultMaxCqSize;
int ibMaxCqPollNum = DefaultMaxCqPollNum;
int ibMaxSendWr = DefaultMaxSendWr;
int ibMaxWrPerSend = DefaultMaxWrPerSend;

size_t nvlsBufferSize = DefaultNvlsBufferSize;

/// Default constructor. Sets transport to Transport::Unknown.
EndpointConfig() : transport(Transport::Unknown) {}

/// Constructor that takes a transport and sets the other fields to their default values.
///
/// @param transport The transport to use.
EndpointConfig(Transport transport) : transport(transport) {}

/// Constructor for NVLS.
/// @param transport Must be either NvlsRoot or NvlsNonRoot.
/// @param nvlsBufferSize The size of the buffer to be allocated on each device.
EndpointConfig(Transport transport, size_t nvlsBufferSize) : transport(transport), nvlsBufferSize(nvlsBufferSize) {}
};
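(Note: with this refactoring the NVLS buffer size no longer lives on EndpointConfig; the default moves to NvlsConnection::DefaultNvlsBufferSize and is passed explicitly to connectNvlsCollective, as shown in the new nvls.hpp below.)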

/// Represents a context for communication. This provides a low-level interface for forming connections in use-cases
@@ -708,16 +657,6 @@ class Communicator {
/// to the connection.
NonblockingFuture<std::shared_ptr<Connection>> connectOnSetup(int remoteRank, int tag, EndpointConfig localConfig);

/// Connect to NVLS on setup.
///
/// This function is used to connect to NVLS on setup. NVLS collectives use multicast operations to send/recv
/// data, so all ranks involved in the collective must be put into the group.
///
/// @param allRanks The ranks of all processes involved in the collective.
/// @param config The configuration for the local endpoint.
/// @return std::shared_ptr<NvlsConnection> A shared pointer to the NVLS connection.
std::shared_ptr<NvlsConnection> connectNvlsCollective(std::vector<int> allRanks, EndpointConfig config);

/// Get the remote rank a connection is connected to.
///
/// @param connection The connection to get the remote rank for.
72 changes: 72 additions & 0 deletions include/mscclpp/nvls.hpp
@@ -0,0 +1,72 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

#ifndef MSCCLPP_NVLS_HPP_
#define MSCCLPP_NVLS_HPP_

#include <mscclpp/gpu_utils.hpp>
#include <mscclpp/nvls_device.hpp>

namespace mscclpp {

class NvlsConnection {
public:
NvlsConnection(size_t bufferSize, int numDevices);
NvlsConnection(const std::vector<char>& data);
NvlsConnection() = delete;
std::vector<char> serialize();

// the recommended buffer size for NVLS, returned by cuMulticastGetGranularity
static const int DefaultNvlsBufferSize = (1 << 29);

// Everyone needs to synchronize after creating an NVLS connection before adding devices
void addDevice();
void addDevice(int cudaDeviceId);

struct DeviceMulticastPointer {
private:
std::shared_ptr<PhysicalCudaMemory<char>> deviceMem_;
std::shared_ptr<char> mcPtr_;
size_t bufferSize_;

public:
using DeviceHandle = DeviceMulticastPointerDeviceHandle;
DeviceMulticastPointer(std::shared_ptr<PhysicalCudaMemory<char>> deviceMem, std::shared_ptr<char> mcPtr,
size_t bufferSize)
: deviceMem_(deviceMem), mcPtr_(mcPtr), bufferSize_(bufferSize) {}
DeviceHandle deviceHandle();
char* getDevicePtr();

friend class NvlsConnection;
};

std::shared_ptr<DeviceMulticastPointer> allocateAndBindCuda(size_t size);

/// Bind an existing allocation to the multicast handle.
/// @param memHandle The handle to the allocation (its lifetime is managed by the caller).
/// @param size The size of the allocation.
std::shared_ptr<char> bindAllocatedCuda(CUmemGenericAllocationHandle memHandle, size_t size);

size_t getMultiCastMinGranularity();

private:
class Impl;
std::shared_ptr<Impl> pimpl_;
};

class Communicator;

/// Connect to NVLS on setup.
///
/// This function is used to connect to NVLS on setup. NVLS collectives use multicast operations to send/recv
/// data, so all ranks involved in the collective must be put into the group.
///
/// @param comm The communicator.
/// @param allRanks The ranks of all processes involved in the collective.
/// @param bufferSize The size of the buffer to be allocated on each device (defaults to DefaultNvlsBufferSize).
/// @return std::shared_ptr<NvlsConnection> A shared pointer to the NVLS connection.
std::shared_ptr<NvlsConnection> connectNvlsCollective(std::shared_ptr<Communicator> comm, std::vector<int> allRanks,
size_t bufferSize = NvlsConnection::DefaultNvlsBufferSize);

} // namespace mscclpp

#endif // MSCCLPP_NVLS_HPP_
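As a review aid, a minimal host-side sketch of the new API (the helper name setupNvls and the sizes are illustrative, not part of this PR; error handling omitted):

#include <mscclpp/core.hpp>
#include <mscclpp/nvls.hpp>

// Illustrative helper: wire up NVLS for a group of ranks.
void setupNvls(std::shared_ptr<mscclpp::Communicator> comm, std::vector<int> allRanks) {
  // Collective call: every rank in allRanks must enter it, since the root
  // creates the multicast handle and broadcasts it over the bootstrap.
  auto nvls = mscclpp::connectNvlsCollective(comm, allRanks);  // uses DefaultNvlsBufferSize
  // Allocation sizes should respect the multicast granularity.
  size_t granularity = nvls->getMultiCastMinGranularity();
  auto mem = nvls->allocateAndBindCuda(granularity);
  // The device handle exposes both the unicast and multicast pointers to kernels.
  auto handle = mem->deviceHandle();
  (void)handle;
}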
14 changes: 7 additions & 7 deletions include/mscclpp/nvls_device.hpp
@@ -22,15 +22,15 @@ struct DeviceMulticastPointerDeviceHandle {
size_t bufferSize;

#if defined(MSCCLPP_DEVICE_CUDA)
template <int NElemPerThread = 4, typename TVaule = float4, typename T = float>
MSCCLPP_DEVICE_INLINE static void multimemLoad(TVaule& val, T* ptr) {
template <int NElemPerThread = 4, typename TValue = float4, typename T = float>
MSCCLPP_DEVICE_INLINE static void multimemLoad(TValue& val, T* ptr) {
static_assert(NElemPerThread == 4, "Only support NElemPerThread == 4");
if constexpr (std::is_same<T, float>::value) {
asm("multimem.ld_reduce.relaxed.sys.global.add.v4.f32 {%0,%1,%2,%3}, [%4];"
: "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w)
: "l"(ptr)
: "memory");
} else if constexpr (std::is_same<T, half2>::value) {
} else if constexpr (std::is_same<T, __half2>::value) {
asm("multimem.ld_reduce.relaxed.sys.global.add.v4.f16x2 {%0,%1,%2,%3}, [%4];"
: "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w)
: "l"(ptr)
@@ -40,14 +40,14 @@ struct DeviceMulticastPointerDeviceHandle {
}
};

template <int NElemPerThread = 4, typename TVaule, typename T>
MSCCLPP_DEVICE_INLINE static void multimemStore(const TVaule& val, T* ptr) {
template <int NElemPerThread = 4, typename TValue, typename T>
MSCCLPP_DEVICE_INLINE static void multimemStore(const TValue& val, T* ptr) {
static_assert(NElemPerThread == 4, "Only support NElemPerThread == 4");
if constexpr (std::is_same<T, float>::value) {
asm volatile("multimem.st.relaxed.sys.global.v4.f32 [%0], {%1,%2,%3,%4};" ::"l"(ptr), "r"(val.x), "r"(val.y),
"r"(val.z), "r"(val.w)
: "memory");
} else if constexpr (std::is_same<T, half2>::value) {
} else if constexpr (std::is_same<T, __half2>::value) {
asm volatile("multimem.st.relaxed.sys.global.v4.f16x2 [%0], {%1,%2,%3,%4};" ::"l"(ptr), "r"(val.x), "r"(val.y),
"r"(val.z), "r"(val.w)
: "memory");
@@ -60,4 +60,4 @@ struct DeviceMulticastPointerDeviceHandle {

} // namespace mscclpp

#endif // MSCCLPP_SEMAPHORE_DEVICE_HPP_
#endif // MSCCLPP_NVLS_DEVICE_HPP_
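And a minimal device-side sketch of how a kernel might drive these primitives for an in-place float all-reduce (illustrative only; assumes the handle comes from DeviceMulticastPointer::deviceHandle(), that mcPtr covers nelem floats, and that nelem is a multiple of 4):

#include <mscclpp/nvls_device.hpp>

// Illustrative kernel: each thread reduces and rebroadcasts 4 floats at a time
// through the multicast pointer.
__global__ void allreduceNvls(mscclpp::DeviceMulticastPointerDeviceHandle handle, size_t nelem) {
  float* mc = reinterpret_cast<float*>(handle.mcPtr);
  size_t i = ((size_t)blockIdx.x * blockDim.x + threadIdx.x) * 4;
  size_t stride = (size_t)gridDim.x * blockDim.x * 4;
  for (; i < nelem; i += stride) {
    float4 val;
    // ld_reduce returns the element-wise sum across all devices in the group...
    mscclpp::DeviceMulticastPointerDeviceHandle::multimemLoad(val, mc + i);
    // ...and the multicast store writes the reduced values back to every device.
    mscclpp::DeviceMulticastPointerDeviceHandle::multimemStore(val, mc + i);
  }
}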
2 changes: 1 addition & 1 deletion python/mscclpp/comm.py
@@ -98,7 +98,7 @@ def make_connection(
else:
endpoint = endpoints
if endpoint.transport == Transport.Nvls:
return self.communicator.connct_nvls_collective(all_ranks, endpoint)
return connect_nvls_collective(self.communicator, all_ranks)
else:
connections[rank] = self.communicator.connect_on_setup(rank, 0, endpoint)
self.communicator.setup()
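(The Python call site now uses the module-level connect_nvls_collective(communicator, all_ranks) helper instead of the removed Communicator method, which also retires the misspelled connct_nvls_collective binding and drops the endpoint argument.)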
22 changes: 2 additions & 20 deletions python/mscclpp/core_py.cpp
@@ -20,6 +20,7 @@ extern void register_fifo(nb::module_& m);
extern void register_semaphore(nb::module_& m);
extern void register_utils(nb::module_& m);
extern void register_numa(nb::module_& m);
extern void register_nvls(nb::module_& m);
extern void register_executor(nb::module_& m);

template <typename T>
@@ -128,24 +129,6 @@ void register_core(nb::module_& m) {
.def("transport", &Connection::transport)
.def("remote_transport", &Connection::remoteTransport);

nb::class_<NvlsConnection::DeviceMulticastPointer>(m, "DeviceMulticastPointer")
.def("get_device_ptr",
[](NvlsConnection::DeviceMulticastPointer* self) { return (uintptr_t)self->getDevicePtr(); })
.def("device_handle", &NvlsConnection::DeviceMulticastPointer::deviceHandle);

nb::class_<NvlsConnection::DeviceMulticastPointer::DeviceHandle>(m, "DeviceHandle")
.def(nb::init<>())
.def_rw("devicePtr", &NvlsConnection::DeviceMulticastPointer::DeviceHandle::devicePtr)
.def_rw("mcPtr", &NvlsConnection::DeviceMulticastPointer::DeviceHandle::mcPtr)
.def_rw("size", &NvlsConnection::DeviceMulticastPointer::DeviceHandle::bufferSize)
.def_prop_ro("raw", [](const NvlsConnection::DeviceMulticastPointer::DeviceHandle& self) -> nb::bytes {
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
});

nb::class_<NvlsConnection>(m, "NvlsConnection")
.def("allocate_bind_memory", &NvlsConnection::allocateAndBindCuda)
.def("get_multicast_min_granularity", &NvlsConnection::getMultiCastMinGranularity);

nb::class_<Endpoint>(m, "Endpoint")
.def("transport", &Endpoint::transport)
.def("serialize", &Endpoint::serialize)
@@ -154,7 +137,6 @@ void register_core(nb::module_& m) {
nb::class_<EndpointConfig>(m, "EndpointConfig")
.def(nb::init<>())
.def(nb::init_implicit<Transport>(), nb::arg("transport"))
.def(nb::init<Transport, size_t>(), nb::arg("transport"), nb::arg("nvlsBufferSize"))
.def_rw("transport", &EndpointConfig::transport)
.def_rw("ib_max_cq_size", &EndpointConfig::ibMaxCqSize)
.def_rw("ib_max_cq_poll_num", &EndpointConfig::ibMaxCqPollNum)
@@ -191,7 +173,6 @@ void register_core(nb::module_& m) {
.def("recv_memory_on_setup", &Communicator::recvMemoryOnSetup, nb::arg("remoteRank"), nb::arg("tag"))
.def("connect_on_setup", &Communicator::connectOnSetup, nb::arg("remoteRank"), nb::arg("tag"),
nb::arg("localConfig"))
.def("connct_nvls_collective", &Communicator::connectNvlsCollective, nb::arg("allRanks"), nb::arg("config"))
.def("remote_rank_of", &Communicator::remoteRankOf)
.def("tag_of", &Communicator::tagOf)
.def("setup", &Communicator::setup);
@@ -206,5 +187,6 @@ NB_MODULE(_mscclpp, m) {
register_utils(m);
register_core(m);
register_numa(m);
register_nvls(m);
register_executor(m);
}
38 changes: 38 additions & 0 deletions python/mscclpp/nvls_py.cpp
@@ -0,0 +1,38 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

#include <nanobind/nanobind.h>
#include <nanobind/operators.h>
#include <nanobind/stl/array.h>
#include <nanobind/stl/shared_ptr.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/vector.h>

#include <mscclpp/core.hpp>
#include <mscclpp/nvls.hpp>

namespace nb = nanobind;
using namespace mscclpp;

void register_nvls(nb::module_& m) {
nb::class_<NvlsConnection::DeviceMulticastPointer>(m, "DeviceMulticastPointer")
.def("get_device_ptr",
[](NvlsConnection::DeviceMulticastPointer* self) { return (uintptr_t)self->getDevicePtr(); })
.def("device_handle", &NvlsConnection::DeviceMulticastPointer::deviceHandle);

nb::class_<NvlsConnection::DeviceMulticastPointer::DeviceHandle>(m, "DeviceHandle")
.def(nb::init<>())
.def_rw("devicePtr", &NvlsConnection::DeviceMulticastPointer::DeviceHandle::devicePtr)
.def_rw("mcPtr", &NvlsConnection::DeviceMulticastPointer::DeviceHandle::mcPtr)
.def_rw("size", &NvlsConnection::DeviceMulticastPointer::DeviceHandle::bufferSize)
.def_prop_ro("raw", [](const NvlsConnection::DeviceMulticastPointer::DeviceHandle& self) -> nb::bytes {
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
});

nb::class_<NvlsConnection>(m, "NvlsConnection")
.def("allocate_bind_memory", &NvlsConnection::allocateAndBindCuda)
.def("get_multicast_min_granularity", &NvlsConnection::getMultiCastMinGranularity);

m.def("connect_nvls_collective", &connectNvlsCollective, nb::arg("communicator"), nb::arg("allRanks"),
nb::arg("bufferSize") = NvlsConnection::DefaultNvlsBufferSize);
}
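(The binding takes its bufferSize default directly from NvlsConnection::DefaultNvlsBufferSize, so Python callers get the same 1 << 29 default as C++.)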
39 changes: 0 additions & 39 deletions src/communicator.cc
@@ -105,45 +105,6 @@ MSCCLPP_API_CPP NonblockingFuture<std::shared_ptr<Connection>> Communicator::con
return NonblockingFuture<std::shared_ptr<Connection>>(connector->connectionPromise_.get_future());
}

MSCCLPP_API_CPP std::shared_ptr<NvlsConnection> Communicator::connectNvlsCollective(std::vector<int> allRanks,
EndpointConfig config) {
auto bootstrap = this->bootstrap();
int rank = bootstrap->getRank();
bool isRoot = false;
bool amongAllRanks = false;
int rootRank = allRanks[0];
for (auto nvlsRank : allRanks) {
if (nvlsRank == rank) amongAllRanks = true;
rootRank = std::min(rootRank, nvlsRank);
}
if (amongAllRanks == false) {
throw Error("rank is not among allRanks", ErrorCode::InvalidUsage);
}
if (rootRank == rank) isRoot = true;

std::shared_ptr<NvlsConnection> conn;
if (isRoot) {
conn = std::make_shared<NvlsConnection>(config.nvlsBufferSize, allRanks.size());
auto serialized = conn->serialize();
for (auto nvlsRank : allRanks) {
if (nvlsRank != rank) bootstrap->send(serialized, nvlsRank, 0);
}
} else {
std::vector<char> data;
bootstrap->recv(data, rootRank, 0);
conn = std::make_shared<NvlsConnection>(data);
}

// Now let's synchronize all ranks
bootstrap->groupBarrier(allRanks);
// now it is safe to add my device
conn->addDevice();

// sync here to make sure all ranks have added their devices
bootstrap->groupBarrier(allRanks);
return conn;
}
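(The logic above — the root rank creates and serializes the multicast handle, non-root ranks deserialize it, and two group barriers bracket addDevice() — presumably moves into the new free function connectNvlsCollective; the new implementation file is not shown in this diff.)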

MSCCLPP_API_CPP int Communicator::remoteRankOf(const Connection& connection) {
return pimpl_->connectionInfos_.at(&connection).remoteRank;
}