[GPUPS] Optimize hbm (PaddlePaddle#11)
* dymf tmp

* add dymf tmp

* local test change

* pull thread pool

* fix conflict

* delete unused log

* local change for mirror 0

* fix dymf

* code clean

* fix code clean

* code clean

* code clean

* fix dymf

* fix dymf

* add endpass optimize

* clean code

* fix endpass optimize

* fix

* fix

* optimize hbm&dymf cuRandState pooling

* clean code

* change thread_num

Co-authored-by: yaoxuefeng6 <yaoxuefeng@baidu.com>
Co-authored-by: Thunderbrook <a754913769@163.com>
3 people authored Jun 27, 2022
1 parent bdd38ba commit a5e3184
Showing 3 changed files with 146 additions and 23 deletions.
23 changes: 21 additions & 2 deletions paddle/fluid/framework/fleet/heter_ps/hashtable_inl.h
@@ -219,6 +219,23 @@ __global__ void dy_mf_update_kernel(Table* table,
}
}

template <typename Table, typename Sgd>
__global__ void dy_mf_update_kernel(Table* table,
                                    const typename Table::key_type* const keys,
                                    const char* const grads,
                                    curandState* p_state, size_t len, Sgd sgd,
                                    size_t grad_value_size) {
  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < len) {
    auto it = table->find(keys[i]);
    if (it != table->end()) {
      FeaturePushValue* cur = (FeaturePushValue*)(grads + i * grad_value_size);
      sgd.dy_mf_update_value((it.getter())->second, *cur, p_state[i]);
    } else {
      if (keys[i] != 0) printf("push miss key: %llu\n", keys[i]);
    }
  }
}

template <typename KeyType, typename ValType>
HashTable<KeyType, ValType>::HashTable(size_t capacity) {
container_ = new TableContainer<KeyType, ValType>(capacity);
@@ -390,10 +407,12 @@ void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
if (len == 0) {
return;
}
auto state = CuRandState::get();
auto d_state = state->get(len, stream);
const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;

dy_mf_update_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(
container_, d_keys, d_grads, len, sgd, push_grad_value_size_);
container_, d_keys, d_grads, d_state, len, sgd, push_grad_value_size_);
CuRandState::push(state, stream);
}

} // end namespace framework
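The CuRandState class itself is not part of this hunk, so the sketch below is a hedged illustration of the pooling pattern the new update() relies on: a free list of objects whose device buffers of curandState are initialized once and grown on demand, handed out with get() and returned with push(), so no curand_init runs on the hot push path. Everything except the get()/get(len, stream)/push() call shape visible above is an assumption, not the actual Paddle implementation.

// Hedged sketch of the cuRandState pooling used by update() above; names are
// illustrative only.
#include <cuda_runtime.h>
#include <curand_kernel.h>

#include <memory>
#include <mutex>
#include <vector>

__global__ void init_curand_states(curandState* states, size_t n,
                                   unsigned long long seed) {
  size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  // The expensive per-state curand_init runs only when a buffer is (re)grown,
  // not on every push.
  if (i < n) curand_init(seed, i, 0, &states[i]);
}

class PooledRandState {
 public:
  ~PooledRandState() {
    if (states_ != nullptr) cudaFree(states_);
  }

  // Member get(): return a device buffer holding at least `len` ready states.
  curandState* get(size_t len, cudaStream_t stream) {
    if (len > capacity_) {
      if (states_ != nullptr) cudaFree(states_);
      cudaMalloc(&states_, len * sizeof(curandState));
      const int block = 256;
      const int grid = static_cast<int>((len + block - 1) / block);
      init_curand_states<<<grid, block, 0, stream>>>(states_, len, 1234ULL);
      capacity_ = len;
    }
    return states_;
  }

  // Static get(): take a pooled object from the free list, or create one.
  static std::shared_ptr<PooledRandState> get() {
    std::lock_guard<std::mutex> lock(mutex_);
    if (free_list_.empty()) return std::make_shared<PooledRandState>();
    auto s = free_list_.back();
    free_list_.pop_back();
    return s;
  }

  // Static push(): hand the object back once the kernel using it is done.
  static void push(std::shared_ptr<PooledRandState> s, cudaStream_t stream) {
    cudaStreamSynchronize(stream);  // a real pool would likely use events
    std::lock_guard<std::mutex> lock(mutex_);
    free_list_.push_back(std::move(s));
  }

 private:
  curandState* states_ = nullptr;
  size_t capacity_ = 0;
  inline static std::mutex mutex_;
  inline static std::vector<std::shared_ptr<PooledRandState>> free_list_;
};

With this shape, the calling pattern in update() above (take a pooled object, size its device buffer to the batch, launch the kernel, push the object back) reuses initialized states across batches instead of re-seeding them every call.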
32 changes: 32 additions & 0 deletions paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h
@@ -152,6 +152,38 @@ class Optimizer {
grad.show); // for local test
}
}

  __device__ void dy_mf_update_value(ValType* ptr, const GradType& grad,
                                     curandState& state) {
    ptr->slot = grad.slot;
    ptr->show += grad.show;
    ptr->clk += grad.clk;
    ptr->delta_score +=
        optimizer_config::nonclk_coeff * (grad.show - grad.clk) +
        optimizer_config::clk_coeff * grad.clk;

    update_lr(ptr->lr, ptr->lr_g2sum, grad.lr_g, grad.show);
    // ptr->mf_dim = grad.mf_dim;

    if (ptr->mf_size == 0) {
      if (optimizer_config::mf_create_thresholds <=
          optimizer_config::nonclk_coeff * (ptr->show - ptr->clk) +
              optimizer_config::clk_coeff * ptr->clk) {
        ptr->mf_size = ptr->mf_dim + 1;
        ptr->mf[0] = 0;
        // mf weights are drawn from the pooled curandState passed in by the
        // caller, so no per-call curand_init is needed here.
        for (int i = 0; i < ptr->mf_dim; ++i) {
          ptr->mf[i + 1] =
              (curand_uniform(&state)) * optimizer_config::mf_initial_range;
        }
      }
    } else {
      update_mf(ptr->mf_dim, &(ptr->mf[1]), ptr->mf[0], grad.mf_g,
                grad.show);  // for local test
    }
  }


};

} // end namespace framework
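To make the motivation for the extra curandState& parameter concrete, the minimal kernels below (illustrative assumptions, not Paddle code) contrast initializing a state inside the kernel with drawing from pre-initialized pooled states: XORWOW curand_init is far more expensive than curand_uniform, so paying it once per pooled state instead of once per push shrinks the update kernel.

// Illustrative contrast only; the pooled pattern is what
// dy_mf_update_value(ptr, grad, state) above uses.
#include <curand_kernel.h>

// Per-call pattern: every thread pays for curand_init on every launch.
__global__ void init_mf_per_call(float* mf, size_t n, int dim, float range) {
  size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= n) return;
  curandState local;
  curand_init(clock64(), tid, 0, &local);  // expensive XORWOW setup
  for (int i = 0; i < dim; ++i) {
    mf[tid * dim + i] = curand_uniform(&local) * range;
  }
}

// Pooled pattern: states arrive pre-initialized; the kernel only draws
// uniforms and advances the state in place.
__global__ void init_mf_pooled(float* mf, size_t n, int dim, float range,
                               curandState* states) {
  size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= n) return;
  curandState& st = states[tid];
  for (int i = 0; i < dim; ++i) {
    mf[tid * dim + i] = curand_uniform(&st) * range;
  }
}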
114 changes: 93 additions & 21 deletions paddle/fluid/framework/fleet/ps_gpu_wrapper.cc
@@ -825,20 +825,76 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr<HeterContext> gpu_task) {
HeterPs_->show_one_table(i);
}
};
auto build_dynamic_mf_func = [this, &gpu_task](int i, int j) {

// multi-thread process

auto build_dymf_mem_pool = [this, &gpu_task](int i, int j) {
this->HeterPs_->set_multi_mf_dim(multi_mf_dim_, max_mf_dim_);
int mf_dim = this->index_dim_vec_[j];
VLOG(3) << "building table: " << i << "with mf dim: " << mf_dim;
size_t feature_value_size =
TYPEALIGN(8, sizeof(FeatureValue) + ((mf_dim + 1) * sizeof(float)));
auto& device_dim_keys = gpu_task->device_dim_keys_[i][j];
auto& device_dim_ptrs = gpu_task->device_dim_ptr_[i][j];
size_t len = device_dim_keys.size();
CHECK(len == device_dim_ptrs.size());
this->mem_pools_[i * this->multi_mf_dim_ + j] =
new MemoryPool(len, feature_value_size);
this->mem_pools_[i * this->multi_mf_dim_ + j] = new MemoryPool(len, feature_value_size);
};

  auto build_dymf_hbm_pool = [this, &gpu_task](int i, int j) {
    auto& device_dim_keys = gpu_task->device_dim_keys_[i][j];
    size_t len = device_dim_keys.size();
    int mf_dim = this->index_dim_vec_[j];
    size_t feature_value_size =
        TYPEALIGN(8, sizeof(FeatureValue) + ((mf_dim + 1) * sizeof(float)));

    auto& mem_pool = this->mem_pools_[i * this->multi_mf_dim_ + j];
    platform::CUDADeviceGuard guard(resource_->dev_id(i));
    this->hbm_pools_[i * this->multi_mf_dim_ + j] = new HBMMemoryPool(mem_pool);
    auto& cur_pool = this->hbm_pools_[i * this->multi_mf_dim_ + j];

    this->HeterPs_->build_ps(i, device_dim_keys.data(), cur_pool->mem(), len,
                             feature_value_size, 500000, 2);
    if (device_dim_keys.size() > 0) {
      VLOG(3) << "show table: " << i
              << " table kv size: " << device_dim_keys.size()
              << " dim: " << mf_dim << " len: " << len;
      HeterPs_->show_one_table(i);
    }
    delete mem_pool;
  };

int thread_num = 16;
auto build_dynamic_mf_func = [this, &gpu_task, thread_num](int i, int j, int z) {
int mf_dim = this->index_dim_vec_[j];
VLOG(3) << "building table: " << i << "with mf dim: " << mf_dim;

auto& device_dim_keys = gpu_task->device_dim_keys_[i][j];
auto& device_dim_ptrs = gpu_task->device_dim_ptr_[i][j];

size_t len = device_dim_keys.size();
CHECK(len == device_dim_ptrs.size());

auto& mem_pool = this->mem_pools_[i * this->multi_mf_dim_ + j];
for (size_t k = 0; k < len; k++) {

// ============ add for multi-thread ================
int len_per_thread = len / thread_num;
int remain = len % thread_num;
int left = -1, right = -1;

int real_len = len_per_thread;
if (z < remain) real_len++;

if (z < remain) {
left = z * (len_per_thread + 1);
right = left + real_len;
} else {
left = remain * (len_per_thread + 1) + (z - remain) * len_per_thread;
right = left + real_len;
}
// ============ add for multi-thread ================

for (int k = left; k < right; k++) {

FeatureValue* val = (FeatureValue*)(mem_pool->mem_address(k));
float* ptr_val = device_dim_ptrs[k]->data();
size_t dim = device_dim_ptrs[k]->size();
@@ -886,33 +942,49 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr<HeterContext> gpu_task) {
}
}
}
platform::CUDADeviceGuard guard(resource_->dev_id(i));
this->hbm_pools_[i * this->multi_mf_dim_ + j] = new HBMMemoryPool(mem_pool);
auto& cur_pool = this->hbm_pools_[i * this->multi_mf_dim_ + j];
this->HeterPs_->build_ps(i, device_dim_keys.data(), cur_pool->mem(), len,
feature_value_size, 500000, 2);
if (device_dim_keys.size() > 0) {
VLOG(3) << "show table: " << i
<< " table kv size: " << device_dim_keys.size()
<< "dim: " << mf_dim << " len: " << len;
HeterPs_->show_one_table(i);
}
delete mem_pool;
};
if (!multi_mf_dim_) {
for (size_t i = 0; i < threads.size(); i++) {
threads[i] = std::thread(build_func, i);
}
for (std::thread& t : threads) {
t.join();
}
threads.clear();
} else {
threads.resize(device_num * multi_mf_dim_);
for (int i = 0; i < device_num; i++) {
for (int j = 0; j < multi_mf_dim_; j++) {
threads[i + j * device_num] = std::thread(build_dynamic_mf_func, i, j);
threads[i + j * device_num] = std::thread(build_dymf_mem_pool, i, j);
}
}
}
for (std::thread& t : threads) {
t.join();
for (std::thread& t : threads) {
t.join();
}
threads.clear();
// multi-thread process
threads.resize(device_num * multi_mf_dim_ * thread_num);
for (int i = 0; i < device_num; i++) {
for (int j = 0; j < multi_mf_dim_; j++) {
for (int k = 0; k < thread_num; k++) {
threads[(i + j * device_num) * thread_num + k] = std::thread(build_dynamic_mf_func, i, j, k);
}
}
}
for (std::thread& t : threads) {
t.join();
}
threads.clear();
threads.resize(device_num * multi_mf_dim_);
for (int i = 0; i < device_num; i++) {
for (int j = 0; j < multi_mf_dim_; j++) {
threads[i + j * device_num] = std::thread(build_dymf_hbm_pool, i, j);
}
}
for (std::thread& t : threads) {
t.join();
}
threads.clear();
}
timeline.Pause();
VLOG(0) << "GpuPs build table total costs: " << timeline.ElapsedSec()
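The block-partition arithmetic inside build_dynamic_mf_func is easy to get wrong, so here is a standalone, hedged sketch of the same left/right computation (the helper name chunk_range is invented for illustration). The chunks tile [0, len) exactly, and when len < thread_num the trailing workers get empty ranges, so their copy loops simply do nothing.

// Standalone sketch of the multi-thread work split used above; chunk_range is
// a hypothetical helper, but the arithmetic mirrors the diff.
#include <cassert>
#include <cstdio>
#include <utility>

std::pair<int, int> chunk_range(int len, int thread_num, int z) {
  int len_per_thread = len / thread_num;
  int remain = len % thread_num;
  int real_len = len_per_thread + (z < remain ? 1 : 0);
  // The first `remain` workers take one extra element each.
  int left = (z < remain)
                 ? z * (len_per_thread + 1)
                 : remain * (len_per_thread + 1) + (z - remain) * len_per_thread;
  return {left, left + real_len};  // half-open [left, right)
}

int main() {
  const int len = 10, thread_num = 16;  // more workers than keys is fine
  int covered = 0;
  for (int z = 0; z < thread_num; ++z) {
    auto range = chunk_range(len, thread_num, z);
    covered += range.second - range.first;
    std::printf("worker %2d -> [%d, %d)\n", z, range.first, range.second);
  }
  assert(covered == len);  // the chunks cover every key exactly once
  return 0;
}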
