Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rewrite mscclpp-test with cpp style API #77

Merged
merged 11 commits into from
May 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions include/mscclpp/channel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,22 @@ struct DeviceChannel {
put(dst, offset, src, offset, size);
}

// Cooperatively copies exactly `size` bytes from `(char*)src + srcOffset` to `(char*)dst + dstOffset`.
//
// The work is split across the calling threads: the thread identified by `threadId` handles every
// `numThreads`-th element, so all `numThreads` participants must call this with distinct ids in
// [0, numThreads). `numThreads` must be non-zero.
//
// Both effective addresses are assumed to be 8-byte aligned. `size` does not have to be a multiple of 8: whole
// 8-byte words are copied first and any remaining tail bytes are copied one by one, so nothing outside
// [offset, offset + size) is read or written.
__forceinline__ __device__ void putDirect(void* dst, void* src, uint64_t dstOffset, uint64_t srcOffset, uint64_t size,
                                          uint32_t threadId, uint32_t numThreads) {
  // assume the memory is aligned to 8 bytes
  char* srcBytes = (char*)src + srcOffset;
  char* dstBytes = (char*)dst + dstOffset;
  uint64_t* srcAddr = (uint64_t*)srcBytes;
  uint64_t* dstAddr = (uint64_t*)dstBytes;

  // Bulk part: full 8-byte words only (rounding down keeps the accesses inside the requested range).
  size_t nElem = size / sizeof(uint64_t);
  for (size_t i = threadId; i < nElem; i += numThreads) {
    // load to register first
    uint64_t ele = srcAddr[i];
    dstAddr[i] = ele;
  }

  // Tail part: at most 7 trailing bytes that do not fill a whole word.
  size_t tailBegin = nElem * sizeof(uint64_t);
  for (size_t i = tailBegin + threadId; i < size; i += numThreads) {
    dstBytes[i] = srcBytes[i];
  }
}

// Forwards to epoch_.signalDirect(); this path does not push anything to fifo_.
__forceinline__ __device__ void signalDirect() { epoch_.signalDirect(); }

__forceinline__ __device__ void signal() {
epochIncrement();
fifo_.push(ChannelTrigger(TriggerFlag, 0, 0, 0, 0, 1, channelId_).value);
Expand Down Expand Up @@ -212,6 +228,9 @@ struct SimpleDeviceChannel {

// Builds a channel whose put()/putWithSignal() address memory through the `dst` and `src` memory ids.
// The raw pointers used by putDirect() are set to nullptr so they never hold indeterminate values.
SimpleDeviceChannel(DeviceChannel devChan, MemoryId dst, MemoryId src)
    : devChan_(devChan), dst_(dst), src_(src), srcPtr_(nullptr), dstPtr_(nullptr) {}

// Builds a channel for direct copies: putDirect() copies between `srcPtr` and `dstPtr`.
// The memory ids are value-initialized so they never hold indeterminate values.
SimpleDeviceChannel(DeviceChannel devChan, void* dstPtr, void* srcPtr)
    : devChan_(devChan), dst_(), src_(), srcPtr_(srcPtr), dstPtr_(dstPtr) {}

SimpleDeviceChannel(const SimpleDeviceChannel& other) = default;

SimpleDeviceChannel& operator=(SimpleDeviceChannel& other) = default;
Expand All @@ -224,8 +243,14 @@ struct SimpleDeviceChannel {

// Convenience overload: passes the same `offset` for both offset arguments of the three-argument put().
__forceinline__ __device__ void put(uint64_t offset, uint64_t size) { put(offset, offset, size); }

// Direct copy between srcPtr_ and dstPtr_, using the same `offset` on both sides. Forwards to
// devChan_.putDirect(); `size`, `threadId` and `numThreads` are passed through unchanged.
__forceinline__ __device__ void putDirect(uint64_t offset, uint64_t size, uint32_t threadId, uint32_t numThreads) {
devChan_.putDirect(dstPtr_, srcPtr_, offset, offset, size, threadId, numThreads);
}

// Forwards to devChan_.signal().
__forceinline__ __device__ void signal() { devChan_.signal(); }

// Forwards to devChan_.signalDirect().
__forceinline__ __device__ void signalDirect() { devChan_.signalDirect(); }

// Forwards to devChan_.putWithSignal() using this channel's dst_ and src_ memory ids together with the
// caller-supplied offsets and size.
__forceinline__ __device__ void putWithSignal(uint64_t dstOffset, uint64_t srcOffset, uint64_t size) {
devChan_.putWithSignal(dst_, dstOffset, src_, srcOffset, size);
}
Expand All @@ -251,6 +276,10 @@ struct SimpleDeviceChannel {
DeviceChannel devChan_;
MemoryId dst_;
MemoryId src_;

// these are used for direct copy
void* srcPtr_;
void* dstPtr_;
};

} // namespace channel
Expand Down
45 changes: 45 additions & 0 deletions include/mscclpp/concurrency.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#ifndef MSCCLPP_CONCURRENCY_HPP_
#define MSCCLPP_CONCURRENCY_HPP_

namespace mscclpp {
// Barrier shared by several thread blocks of one kernel. All cooperating blocks must call sync() on the
// same DeviceSyncer object.
struct DeviceSyncer {
 public:
  // The defaulted constructor does not initialize flag_, count_ or isAdd_.
  // NOTE(review): sync() only works if they start at zero (e.g. a zero-initialized global or a cleared
  // allocation) - confirm at the places where DeviceSyncer objects are created.
  DeviceSyncer() = default;
  ~DeviceSyncer() = default;

#ifdef __CUDACC__
  // Synchronize multiple thread blocks inside a kernel. Guarantee that all
  // previous work of all threads in cooperating blocks is finished.
  //
  // blockNum is the number of blocks taking part in the barrier; every one of them must call sync() with
  // the same value. Only the thread with threadIdx.x == 0 of each block performs the cross-block handshake;
  // the remaining threads of the block wait at the surrounding __syncthreads() calls.
  //
  // Successive calls alternate between counting up (flag_ becomes 1 once the last block arrives) and
  // counting down (flag_ returns to 0), so the same object can be reused without being reset.
  __forceinline__ __device__ void sync(int blockNum) {
    int maxOldCnt = blockNum - 1;
    __syncthreads();
    if (threadIdx.x == 0) {
      // Make the global-memory writes issued by this block before the barrier visible device-wide before
      // this block announces its arrival. Without the fence other blocks could leave the barrier and still
      // read stale data.
      __threadfence();
      int tmpIsAdd = isAdd_ ^ 1;
      if (tmpIsAdd) {
        // Up phase: the block that observes the old count blockNum - 1 is the last one to arrive.
        if (atomicAdd(&count_, 1) == maxOldCnt) {
          flag_ = 1;
        }
        while (!flag_) {
        }
      } else {
        // Down phase: the block that observes the old count 1 is the last one to arrive.
        if (atomicSub(&count_, 1) == 1) {
          flag_ = 0;
        }
        while (flag_) {
        }
      }
      isAdd_ = tmpIsAdd;
    }
    // We need sync here because only a single thread is checking whether
    // the flag is flipped.
    __syncthreads();
  }
#endif

 private:
  // Flipped by the last block to arrive; the waiting blocks spin on it.
  volatile int flag_;
  // Number of blocks that have arrived in the current phase (modified with atomicAdd/atomicSub).
  int count_;
  // Phase selector: the next sync() counts up when this is 0 and counts down when it is 1.
  int isAdd_;
};
} // namespace mscclpp
#endif // MSCCLPP_CONCURRENCY_HPP_
11 changes: 10 additions & 1 deletion include/mscclpp/epoch.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ class BaseEpoch {
private:
std::shared_ptr<Connection> connection_;
RegisteredMemory localEpochIdsRegMem_;
NonblockingFuture<RegisteredMemory> remoteEpochIdsRegMem_;

protected:
NonblockingFuture<RegisteredMemory> remoteEpochIdsRegMem_;
std::unique_ptr<EpochIds, Deleter<EpochIds>> epochIds_;
std::unique_ptr<uint64_t, Deleter<uint64_t>> expectedInboundEpochId_;

Expand Down Expand Up @@ -56,9 +56,18 @@ class DeviceEpoch : BaseEpoch<CudaDeleter> {
}

// Adds 1 to epochIds->outbound through a volatile-qualified access.
__forceinline__ __device__ void epochIncrement() { *(volatile uint64_t*)&(epochIds->outbound) += 1; }

// Increments the outbound epoch id and stores the new value into remoteEpochIds->inboundReplica.
__forceinline__ __device__ void signalDirect() {
  // System-scope fence: data written by an earlier putDirect() has to be visible on the peer GPU before the
  // peer can observe the incremented epoch id.
  __threadfence_system();
  epochIncrement();
  // Store the freshly incremented outbound id into the remote inbound replica through a volatile pointer.
  volatile uint64_t* peerInbound = (volatile uint64_t*)&(remoteEpochIds->inboundReplica);
  *peerInbound = epochIds->outbound;
}
#endif // __CUDACC__

EpochIds* epochIds;
EpochIds* remoteEpochIds;
uint64_t* expectedInboundEpochId;
};

Expand Down
17 changes: 16 additions & 1 deletion src/include/utils.hpp → include/mscclpp/utils.hpp
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
#ifndef MSCCLPP_UTILS_HPP_
#define MSCCLPP_UTILS_HPP_

#include <stdio.h>
#include <unistd.h>

#include <chrono>
#include <cstdio>
#include <cstring>
#include <stdexcept>
#include <string>

namespace mscclpp {

Expand Down Expand Up @@ -35,6 +38,18 @@ struct ScopedTimer {
~ScopedTimer() { timer.print(name); }
};

// Returns the local host name, cut at the first occurrence of `delim` (or at the terminating NUL if `delim`
// does not occur).
//
// maxlen: number of bytes handed to gethostname(); the result holds at most maxlen - 1 characters.
// delim:  character at which the name is truncated (e.g. '.' to drop the domain part).
//
// The returned string has exactly the length of the truncated name (no trailing NUL padding).
// Throws std::runtime_error if gethostname() fails.
inline std::string getHostName(int maxlen, const char delim) {
  // One extra byte guarantees NUL termination even if gethostname() truncates without terminating.
  std::string hostname(maxlen + 1, '\0');
  if (gethostname(&hostname[0], maxlen) != 0) {
    throw std::runtime_error("gethostname failed");
  }
  int i = 0;
  while ((hostname[i] != delim) && (hostname[i] != '\0') && (i < maxlen - 1)) i++;
  // Shrink to the truncated length so size(), comparison and concatenation behave as expected.
  hostname.resize(i);
  return hostname;
}

} // namespace mscclpp

#endif // MSCCLPP_UTILS_HPP_
2 changes: 1 addition & 1 deletion src/connection.cc
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
#include "connection.hpp"

#include <algorithm>
#include <mscclpp/utils.hpp>

#include "checks_internal.hpp"
#include "debug.h"
#include "infiniband/verbs.h"
#include "npkit/npkit.h"
#include "registered_memory.hpp"
#include "utils.hpp"

namespace mscclpp {

Expand Down
1 change: 1 addition & 0 deletions src/epoch.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ MSCCLPP_API_CPP void DeviceEpoch::signal() { BaseEpoch::signal(); }

MSCCLPP_API_CPP DeviceEpoch::DeviceHandle DeviceEpoch::deviceHandle() {
DeviceEpoch::DeviceHandle device;
device.remoteEpochIds = reinterpret_cast<EpochIds*>(remoteEpochIdsRegMem_.get().data());
device.epochIds = epochIds_.get();
device.expectedInboundEpochId = expectedInboundEpochId_.get();
return device;
Expand Down
2 changes: 1 addition & 1 deletion src/proxy.cc
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
#include <atomic>
#include <mscclpp/core.hpp>
#include <mscclpp/proxy.hpp>
#include <mscclpp/utils.hpp>
#include <thread>

#include "api.h"
#include "utils.h"
#include "utils.hpp"

namespace mscclpp {

Expand Down
3 changes: 3 additions & 0 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,6 @@ add_executable(unit_tests)
target_link_libraries(unit_tests GTest::gtest_main GTest::gmock_main mscclpp CUDA::cudart CUDA::cuda_driver)
add_subdirectory(unit) # This adds the sources to the mscclpp target
gtest_discover_tests(unit_tests DISCOVERY_MODE PRE_TEST)

# mscclpp-test
add_subdirectory(mscclpp-test)
Loading