Add XPU multi-node training support with sparse map #10

Merged · 1 commit · Sep 20, 2023
6 changes: 6 additions & 0 deletions paddle/fluid/framework/fleet/box_wrapper.h
@@ -387,6 +387,11 @@ class BoxWrapper {
boxps::MPICluster::Ins();
#ifdef PADDLE_WITH_XPU_KP
box_wrapper_kernel_ = std::make_unique<BoxWrapperKernel>();
+    use_xpu_sparse_map_ = false;
+    auto env_str = std::getenv("USE_XPU_SPARSE_MAP");
+    if (env_str != nullptr && (strcmp(env_str, "true") == 0 || strcmp(env_str, "1") == 0)) {
+      use_xpu_sparse_map_ = true;
+    }
#endif
#ifdef TRACE_PROFILE
// Client side to produce the tracepoints.
@@ -793,6 +798,7 @@ class BoxWrapper {
size_t input_table_dim_ = 0;
int gpu_num_ = GetDeviceCount();
#ifdef PADDLE_WITH_XPU_KP
+  bool use_xpu_sparse_map_;
std::vector<uint64_t> * fid2sign_map_;
std::unique_ptr<BoxWrapperKernel> box_wrapper_kernel_;
#endif
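The toggle is read exactly once, in the BoxWrapper constructor, so it must be exported before the wrapper is built. Below is a minimal sketch of the same parsing rule in isolation; ReadXpuSparseMapFlag is a hypothetical helper, not part of the diff, and setenv is the POSIX call:

#include <cstdlib>
#include <cstring>

// Accepts "true" or "1" as enabled; anything else, or an unset variable,
// leaves the 64-bit sparse-map path disabled.
static bool ReadXpuSparseMapFlag() {
  const char* env_str = std::getenv("USE_XPU_SPARSE_MAP");
  return env_str != nullptr &&
         (std::strcmp(env_str, "true") == 0 || std::strcmp(env_str, "1") == 0);
}

int main() {
  setenv("USE_XPU_SPARSE_MAP", "1", /*overwrite=*/1);  // must happen before BoxWrapper is constructed
  return ReadXpuSparseMapFlag() ? 0 : 1;               // exits 0 when the flag is on
}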
22 changes: 14 additions & 8 deletions paddle/fluid/framework/fleet/box_wrapper_impl.h
@@ -315,13 +315,13 @@ void BoxWrapper::PullSparseCaseXPU(const paddle::platform::Place& place,
TRACE_SCOPE_START("copy keys", xpu_wait(ctx_xpu->xpu_stream));
VLOG(3) << "Begin copy keys, key_num[" << total_length << "]";
LoDTensor& total_keys_tensor = dev.keys_tensor;
-  uint32_t* total_keys;
+  uint64_t* total_keys;
   if(use_l3_tensor) {
-    total_keys = reinterpret_cast<uint32_t*>(
-        total_keys_tensor.mutable_data<int32_t>({total_length, 1}, l3_place));
+    total_keys = reinterpret_cast<uint64_t*>(
+        total_keys_tensor.mutable_data<int64_t>({total_length, 1}, l3_place));
   } else {
-    total_keys = reinterpret_cast<uint32_t*>(
-        total_keys_tensor.mutable_data<int32_t>({total_length, 1}, place));
+    total_keys = reinterpret_cast<uint64_t*>(
+        total_keys_tensor.mutable_data<int64_t>({total_length, 1}, place));
   }
int* key2slot = nullptr;
key2slot = reinterpret_cast<int*>(
@@ -347,9 +347,15 @@ void BoxWrapper::PullSparseCaseXPU(const paddle::platform::Place& place,
XPU_HOST_TO_DEVICE);

TRACE_SCOPE_START("CopyKeys", xpu_wait(ctx_xpu->xpu_stream));
-  box_wrapper_kernel_->CopyKeys(place, xpu_keys, total_keys, slot_lens,
+  if (use_xpu_sparse_map_) {
+    box_wrapper_kernel_->CopyKeys(place, xpu_keys, (unsigned long long *)total_keys, slot_lens,
                                  static_cast<int>(slot_lengths.size()),
                                  static_cast<int>(total_length), key2slot);
+  } else {
+    box_wrapper_kernel_->CopyKeys(place, xpu_keys, (uint32_t *)total_keys, slot_lens,
+                                  static_cast<int>(slot_lengths.size()),
+                                  static_cast<int>(total_length), key2slot);
+  }
VLOG(3) << "Begin call PullSparseXPU in BoxPS, dev: " << device_id
<< " len: " << total_length;
TRACE_SCOPE_END("CopyKeys", xpu_wait(ctx_xpu->xpu_stream));
@@ -533,8 +539,8 @@ void BoxWrapper::PushSparseGradCaseXPU(const paddle::platform::Place& place,
int64_t total_bytes = total_length * feature_push_size_;
void* total_grad_values_xpu =
dev.pull_push_tensor.mutable_data<void>(total_bytes, place);
-  uint32_t* total_keys =
-      reinterpret_cast<uint32_t*>(dev.keys_tensor.data<int32_t>());
+  uint64_t* total_keys =
+      reinterpret_cast<uint64_t*>(dev.keys_tensor.data<int64_t>());
int* total_dims = reinterpret_cast<int*>(dev.dims_tensor.data<int>());
int slot_num = static_cast<int>(slot_lengths.size());

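The keys tensor is now always allocated 64 bits wide; the sparse-map flag only decides whether CopyKeys keeps the full key or truncates it to 32 bits. A host-side sketch of that contract follows; CopyKeysU32 and CopyKeysU64 are illustrative stand-ins for the XPU kernels, not the real ones:

#include <cassert>
#include <cstdint>
#include <vector>

// Truncating path: only the low 32 bits of each key survive.
void CopyKeysU32(uint32_t* dst, const uint64_t* src, int n) {
  for (int i = 0; i < n; ++i) dst[i] = static_cast<uint32_t>(src[i]);
}
// Sparse-map path: keys are copied losslessly.
void CopyKeysU64(unsigned long long* dst, const uint64_t* src, int n) {
  for (int i = 0; i < n; ++i) dst[i] = src[i];
}

int main() {
  std::vector<uint64_t> total_keys(4);              // always 64-bit-wide storage
  const uint64_t src[4] = {1, 2, (1ull << 40) | 7, 42};
  bool use_xpu_sparse_map = true;                   // mirrors USE_XPU_SPARSE_MAP
  if (use_xpu_sparse_map) {
    CopyKeysU64(reinterpret_cast<unsigned long long*>(total_keys.data()), src, 4);
    assert(total_keys[2] == ((1ull << 40) | 7));    // high bits preserved
  } else {
    CopyKeysU32(reinterpret_cast<uint32_t*>(total_keys.data()), src, 4);
  }
  return 0;
}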
4 changes: 4 additions & 0 deletions paddle/fluid/framework/fleet/box_wrapper_kernel.h
@@ -32,6 +32,10 @@ void CopyKeys(const paddle::platform::Place& place,
uint64_t** origin_keys, uint32_t* total_keys,
const int64_t* xpu_len, int slot_num,
int total_len, int* key2slot);
+void CopyKeys(const paddle::platform::Place& place,
+              uint64_t** origin_keys, unsigned long long* total_keys,
+              const int64_t* xpu_len, int slot_num,
+              int total_len, int* key2slot);

void CopyForPull(
const paddle::platform::Place& place, uint64_t** xpu_keys,
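The two declarations differ only in the type of total_keys, so the explicit cast at the call site in PullSparseCaseXPU is what selects the overload at compile time; uint64_t* and unsigned long long* are distinct pointer types even on platforms where both are 64-bit, which is why the cast is required. A minimal sketch of the resolution, with hypothetical no-op bodies and a hypothetical Dispatch helper:

#include <cstdint>

void CopyKeys(uint32_t* total_keys) { /* 32-bit truncating path */ }
void CopyKeys(unsigned long long* total_keys) { /* 64-bit sparse-map path */ }

void Dispatch(uint64_t* buf, bool use_xpu_sparse_map) {
  if (use_xpu_sparse_map) {
    CopyKeys(reinterpret_cast<unsigned long long*>(buf));  // picks the u64 overload
  } else {
    CopyKeys(reinterpret_cast<uint32_t*>(buf));            // picks the u32 overload
  }
}

int main() {
  uint64_t buf[4] = {};
  Dispatch(buf, true);
  Dispatch(buf, false);
  return 0;
}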
85 changes: 72 additions & 13 deletions paddle/fluid/framework/fleet/box_wrapper_kernel.kps
@@ -24,9 +24,10 @@ limitations under the License. */
#include "paddle/fluid/framework/fleet/box_wrapper_kernel.h"
#include "paddle/fluid/platform/device_context.h"
#include "xpu/kernel/cluster_header.h" // NOLINT
-// #include "xpu/kernel/debug.h" // NOLINT
+//#include "xpu/kernel/debug.h" // NOLINT
#include "xpu/kernel/math.h" // NOLINT
#include "xpu/kernel/simd.h"

// The producer side.
#include <scalopus_tracing/tracing.h>
#include <scalopus_transport/transport_loopback.h>
@@ -69,7 +70,7 @@ static inline __device__ void xpu_sync_all(int group_mask = -1) {
__asm__("sync_group csr3");
}

-__global__ void CopyKeysKernel(unsigned long long* src_keys,
+__global__ void CopyKeysKernel_u32(unsigned long long* src_keys,
uint32_t* dest_total_keys,
const long long* len, int slot_num,
int total_len, int* key2slots) {
@@ -81,9 +82,6 @@ __global__ void CopyKeysKernel_u32(unsigned long long* src_keys,
int thread_id = ncores * cluster_id() + cid;
int nthreads = ncores * cluster_num();

-  // __local__ long long local_len[200];
-
-  //__local__ long long local_len[slot_num + 1];
__shared__ long long slot_lens[BoxWrapperKernel::MAX_SLOT_SIZE + 1];
__shared__ __global_ptr__ unsigned long long* slot_keys[BoxWrapperKernel::MAX_SLOT_SIZE];
for (int i = cid; i <= slot_num; i += ncores) {
@@ -95,15 +93,10 @@ __global__ void CopyKeysKernel(unsigned long long* src_keys,
}
}
sync_cluster();
-  //GM2LM(len, local_len, (slot_num + 1) * sizeof(long long));
-
-  //__global_ptr__ unsigned long long* local_keys[slot_num];
-  //GM2LM(src_keys, local_keys, slot_num * sizeof(__global_ptr__ unsigned long long*));

for (int i = thread_id; i < slot_num; i += nthreads) {
// max core local memory = 8KB
int slot_len = slot_lens[i + 1] - slot_lens[i];
-    // int read_len = min(slot_len, 1024);
int read_len = 100;
int dest_offset = slot_lens[i];
__local__ unsigned long long local_slot_keys[read_len];
@@ -120,8 +113,57 @@ __global__ void CopyKeysKernel(unsigned long long* src_keys,
}
mfence();
LM2GM(local_slot_uint32, key2slots + dest_offset + k, real_read_len * sizeof(int));
-      LM2GM(local_slot_keys_uint32, dest_total_keys + dest_offset + k,
-            real_read_len * sizeof(uint32_t));
+      LM2GM(local_slot_keys_uint32, dest_total_keys + dest_offset + k, real_read_len * sizeof(uint32_t));
}
}
}

+__global__ void CopyKeysKernel_u64(unsigned long long* src_keys,
+                                   unsigned long long* dest_total_keys,
+                                   const long long* len, int slot_num,
+                                   int total_len, int* key2slots) {
+  int cid = core_id();
+  int ncores = core_num();
+  if (cid >= ncores) {
+    return;
+  }
+  int thread_id = ncores * cluster_id() + cid;
+  int nthreads = ncores * cluster_num();
+
+  __shared__ long long slot_lens[BoxWrapperKernel::MAX_SLOT_SIZE + 1];
+  __shared__ __global_ptr__ unsigned long long* slot_keys[BoxWrapperKernel::MAX_SLOT_SIZE];
+  for (int i = cid; i <= slot_num; i += ncores) {
+    if (i <= slot_num) {
+      GM2SM(&len[i], &slot_lens[i], sizeof(long long));
+    }
+    if (i < slot_num) {
+      GM2SM(&src_keys[i], &slot_keys[i], sizeof(__global_ptr__ unsigned long long*));
+    }
+  }
+  sync_cluster();
+
+  for (int i = thread_id; i < slot_num; i += nthreads) {
+    // max core local memory = 8KB
+    int slot_len = slot_lens[i + 1] - slot_lens[i];
+    int read_len = 100;
+    int dest_offset = slot_lens[i];
+    __local__ unsigned long long local_slot_keys[read_len];
+    //__local__ uint32_t local_slot_keys_uint32[read_len];
+    //T * local_converted = (T *)local_slot_keys;
+    __local__ int local_slot_uint32[read_len];
+
+    for (int k = 0; k < slot_len; k += read_len) {
+      int real_read_len = min(read_len, slot_len - k);
+      GM2LM(slot_keys[i] + k, local_slot_keys,
+            real_read_len * sizeof(unsigned long long));
+      for (int m = 0; m < real_read_len; m++) {
+        //local_slot_keys_uint32[m] = (uint32_t)local_slot_keys[m];
+        //local_converted[m] = (T)(local_slot_keys[m]);
+        local_slot_uint32[m] = i;
+      }
+      mfence();
+      LM2GM(local_slot_uint32, key2slots + dest_offset + k, real_read_len * sizeof(int));
+      LM2GM(local_slot_keys, dest_total_keys + dest_offset + k, real_read_len * sizeof(unsigned long long));
+    }
+  }
+}
@@ -139,7 +181,24 @@ void BoxWrapperKernel::CopyKeys(const paddle::platform::Place& place,
reinterpret_cast<unsigned long long*>(origin_keys);
const long long* c_len = (const long long*)gpu_len;
CHECK(slot_num <= BoxWrapperKernel::MAX_SLOT_SIZE);
-  CopyKeysKernel<<<2, 64, stream>>>(o_keys, total_keys, c_len, slot_num, total_len, key2slots);
+  CopyKeysKernel_u32<<<2, 64, stream>>>(o_keys, total_keys, c_len, slot_num, total_len, key2slots);
xpu_wait(stream);
}

+void BoxWrapperKernel::CopyKeys(const paddle::platform::Place& place,
+                                uint64_t** origin_keys, unsigned long long* total_keys,
+                                const int64_t* gpu_len, int slot_num,
+                                int total_len, int* key2slots) {
+  XPUStream stream = nullptr;
+  auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
+  stream = static_cast<platform::XPUDeviceContext*>(dev_ctx)
+               ->x_context()
+               ->xpu_stream;
+  unsigned long long* o_keys =
+      reinterpret_cast<unsigned long long*>(origin_keys);
+  const long long* c_len = (const long long*)gpu_len;
+  CHECK(slot_num <= BoxWrapperKernel::MAX_SLOT_SIZE);
+  CopyKeysKernel_u64<<<2, 64, stream>>>(o_keys, total_keys, c_len, slot_num, total_len, key2slots);
+  xpu_wait(stream);
+}

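Both host wrappers launch with <<<2, 64, stream>>> (2 clusters of 64 cores each), and both kernels stage keys through a fixed 100-element scratch buffer because core local memory is small (the in-kernel comment says about 8 KB; 100 keys of 8 bytes plus the int slot-id buffer stays well under that). A host-side model of that tiling loop, assuming plain memcpy in place of GM2LM/LM2GM and ordinary arrays in place of __local__ memory:

#include <algorithm>
#include <cstdint>
#include <cstring>

// Copies one slot's keys to the destination in 100-key tiles, writing the
// slot id into key2slot for every copied key, mirroring CopyKeysKernel_u64.
void CopySlotKeys(const uint64_t* slot_keys, uint64_t* dest, int* key2slot,
                  int slot_id, int slot_len) {
  const int read_len = 100;                       // fixed tile size, as in the kernel
  uint64_t local_keys[read_len];                  // stands in for __local__ memory
  int local_slots[read_len];
  for (int k = 0; k < slot_len; k += read_len) {
    int real_read_len = std::min(read_len, slot_len - k);
    std::memcpy(local_keys, slot_keys + k, real_read_len * sizeof(uint64_t));  // ~GM2LM
    for (int m = 0; m < real_read_len; ++m) local_slots[m] = slot_id;
    std::memcpy(dest + k, local_keys, real_read_len * sizeof(uint64_t));       // ~LM2GM
    std::memcpy(key2slot + k, local_slots, real_read_len * sizeof(int));
  }
}

int main() {
  const int n = 250;  // forces two full tiles plus a 50-key remainder
  uint64_t src[n], dst[n];
  int key2slot[n];
  for (int i = 0; i < n; ++i) src[i] = (1ull << 40) | i;
  CopySlotKeys(src, dst, key2slot, /*slot_id=*/3, n);
  return (dst[n - 1] == src[n - 1] && key2slot[0] == 3) ? 0 : 1;
}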