Skip to content

Commit

Permalink
geo memory sparse table (#39250)
Browse files Browse the repository at this point in the history
* geo depends

* add memory geo table

* fix
  • Loading branch information
zhaocaibei123 authored Jan 30, 2022
1 parent bafea65 commit 9b3b53b
Show file tree
Hide file tree
Showing 13 changed files with 604 additions and 33 deletions.
146 changes: 132 additions & 14 deletions paddle/fluid/distributed/ps/service/brpc_ps_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ int32_t BrpcPsClient::initialize() {
auto &profiler = CostProfiler::instance();
profiler.register_profiler("pserver_client_pull_dense");
profiler.register_profiler("pserver_client_pull_sparse");
profiler.register_profiler("pserver_client_pull_sparse_param");
profiler.register_profiler("pserver_client_pull_sparse_local");
profiler.register_profiler("pserver_client_push_sparse");
profiler.register_profiler("pserver_client_push_sparse_parse");
Expand Down Expand Up @@ -543,6 +544,7 @@ std::future<int32_t> BrpcPsClient::pull_geo_param(size_t table_id,
return fut;
}

// for GEO
std::future<int32_t> BrpcPsClient::push_sparse_param(
size_t table_id, const uint64_t *keys, const float **update_values,
size_t num, void *done) {
Expand All @@ -558,18 +560,8 @@ std::future<int32_t> BrpcPsClient::push_sparse_param(
ids.resize(request_call_num);
value_ptrs.resize(request_call_num);

const auto &server_param = _config.server_param().downpour_server_param();
uint64_t shard_num = FLAGS_pserver_sparse_table_shard_num;
for (int i = 0; i < server_param.downpour_table_param_size(); ++i) {
const auto &table_param = server_param.downpour_table_param(i);
if (table_param.table_id() == table_id) {
shard_num = table_param.shard_num();
break;
}
}

for (size_t i = 0; i < num; ++i) {
size_t pserver_idx = get_sparse_shard(shard_num, request_call_num, keys[i]);
size_t pserver_idx = keys[i] % request_call_num;
ids[pserver_idx].push_back(keys[i]);
value_ptrs[pserver_idx].push_back(update_values[i]);
}
Expand Down Expand Up @@ -1003,6 +995,120 @@ std::future<int32_t> BrpcPsClient::pull_sparse(float **select_values,
return fut;
}

// for GEO
std::future<int32_t> BrpcPsClient::pull_sparse_param(float **select_values,
size_t table_id,
const uint64_t *keys,
size_t num,
bool is_training) {
auto timer = std::make_shared<CostTimer>("pserver_client_pull_sparse_param");
size_t request_call_num = _server_channels.size();

auto shard_sorted_kvs = std::make_shared<
std::vector<std::vector<std::pair<uint64_t, float *>>>>();
shard_sorted_kvs->resize(request_call_num);

for (size_t i = 0; i < num; ++i) {
size_t shard_id = keys[i] % request_call_num;
shard_sorted_kvs->at(shard_id).push_back({keys[i], select_values[i]});
}

auto *accessor = table_accessor(table_id);
size_t value_size = accessor->select_size();

DownpourBrpcClosure *closure = new DownpourBrpcClosure(
request_call_num, [shard_sorted_kvs, value_size](void *done) {
int ret = 0;
auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
for (size_t i = 0; i < shard_sorted_kvs->size(); ++i) {
if (closure->check_response(i, PS_PULL_SPARSE_TABLE) != 0) {
ret = -1;
break;
}

auto &request_kvs = shard_sorted_kvs->at(i);
auto &res_io_buffer = closure->cntl(i)->response_attachment();
butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
uint64_t last_key = UINT64_MAX;
float *last_value_data = NULL;

// can remove sort&unique
for (size_t kv_idx = 0; kv_idx < request_kvs.size(); ++kv_idx) {
auto *kv_pair = &(request_kvs[kv_idx]);
if (kv_pair->first == last_key) {
memcpy(reinterpret_cast<void *>(kv_pair->second),
reinterpret_cast<void *>(last_value_data), value_size);
} else {
last_key = kv_pair->first;
last_value_data = kv_pair->second;
if (value_size !=
io_buffer_itr.copy_and_forward(
reinterpret_cast<void *>(last_value_data), value_size)) {
LOG(WARNING) << "res data is lack or not in format";
ret = -1;
break;
}
}
}
}
closure->set_promise_value(ret);
});
closure->add_timer(timer);
auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise);
std::future<int> fut = promise->get_future();

for (size_t i = 0; i < request_call_num; ++i) {
auto &sorted_kvs = shard_sorted_kvs->at(i);
std::sort(sorted_kvs.begin(), sorted_kvs.end(),
[](const std::pair<uint64_t, float *> &k1,
const std::pair<uint64_t, float *> &k2) {
return k1.first < k2.first;
});

uint64_t last_key = UINT64_MAX;
uint32_t kv_request_count = 0;
size_t sorted_kv_size = sorted_kvs.size();
auto &request_buffer = closure->cntl(i)->request_attachment();

request_buffer.append(reinterpret_cast<void *>(&is_training), sizeof(bool));
std::vector<uint32_t> keys_counter;
keys_counter.reserve(sorted_kv_size);

for (size_t kv_idx = 0; kv_idx < sorted_kv_size; ++kv_idx) {
++kv_request_count;
uint32_t keys = 1;
last_key = sorted_kvs[kv_idx].first;
request_buffer.append(reinterpret_cast<void *>(&last_key),
sizeof(uint64_t));
while (kv_idx < sorted_kv_size - 1 &&
last_key == sorted_kvs[kv_idx + 1].first) {
++kv_idx;
++keys;
}
keys_counter.push_back(keys);
}

request_buffer.append(reinterpret_cast<void *>(keys_counter.data()),
sizeof(uint32_t) * keys_counter.size());

if (kv_request_count == 0) {
closure->Run();
} else {
closure->request(i)->set_cmd_id(PS_PULL_SPARSE_TABLE);
closure->request(i)->set_table_id(table_id);
closure->request(i)->set_client_id(_client_id);
closure->request(i)->add_params((char *)&kv_request_count, // NOLINT
sizeof(uint32_t));
PsService_Stub rpc_stub(get_cmd_channel(i));
closure->cntl(i)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(i), closure->request(i),
closure->response(i), closure);
}
}
return fut;
}

std::future<int32_t> BrpcPsClient::send_client2client_msg(
int msg_type, int to_client_id, const std::string &msg) {
auto promise = std::make_shared<std::promise<int32_t>>();
Expand Down Expand Up @@ -1067,12 +1173,14 @@ int32_t BrpcPsClient::recv_and_save_table(const uint64_t table_id,
std::string var_name = "";
int64_t var_num = 0;
int64_t var_shape = 0;
std::string table_class;
const auto &worker_param = _config.worker_param().downpour_worker_param();
for (size_t i = 0; i < worker_param.downpour_table_param_size(); ++i) {
if (worker_param.downpour_table_param(i).table_id() == table_id) {
var_name = worker_param.downpour_table_param(i).common().table_name();
var_num = worker_param.downpour_table_param(i).common().table_num();
var_shape = worker_param.downpour_table_param(i).common().table_dim();
table_class = worker_param.downpour_table_param(i).table_class();
break;
}
}
Expand All @@ -1094,9 +1202,19 @@ int32_t BrpcPsClient::recv_and_save_table(const uint64_t table_id,
save_vec.push_back(save_huge_vec.data() + i * var_shape);
}

auto status = pull_sparse(reinterpret_cast<float **>(save_vec.data()),
table_id, save_key.data(), save_key.size(), true);
status.wait();
VLOG(2) << "recv_and_save_table: table_class: " << table_class;
// TODO(zhaocaibei123): new GeoBrpcPSClient, move this to its
// recv_and_save_table
if (table_class == "MemorySparseGeoTable") {
auto status =
pull_sparse_param(reinterpret_cast<float **>(save_vec.data()), table_id,
save_key.data(), save_key.size(), true);
status.wait();
} else {
auto status = pull_sparse(reinterpret_cast<float **>(save_vec.data()),
table_id, save_key.data(), save_key.size(), true);
status.wait();
}

// create lod tensor
std::shared_ptr<framework::Scope> scope;
Expand Down
4 changes: 4 additions & 0 deletions paddle/fluid/distributed/ps/service/brpc_ps_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,10 @@ class BrpcPsClient : public PSClient {
size_t table_id,
const uint64_t *keys, size_t num,
bool is_training);
virtual std::future<int32_t> pull_sparse_param(float **select_values,
size_t table_id,
const uint64_t *keys,
size_t num, bool is_training);

virtual std::future<int32_t> print_table_stat(uint32_t table_id);

Expand Down
32 changes: 22 additions & 10 deletions paddle/fluid/distributed/ps/service/communicator/communicator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ void Communicator::RpcRecvSparse(const std::string &varname, int table_id,

bool training = true;

auto status = _worker_ptr->pull_sparse(
auto status = _worker_ptr->pull_sparse_param(
(float **)push_g_vec.data(), table_id, // NOLINT
sparse_push_keys.data(), sparse_push_keys.size(), training);
status.wait();
Expand Down Expand Up @@ -1029,7 +1029,7 @@ void GeoCommunicator::Send(const std::vector<std::string> &var_names,
auto &sparse_ids_set = iter.second;
auto sparse_ids_vec = std::make_shared<std::vector<int64_t>>();
sparse_ids_vec->assign(sparse_ids_set.begin(), sparse_ids_set.end());
sparse_id_queues_.at(key)->Push(sparse_ids_vec);
sparse_id_queues_.at(key)->Put(sparse_ids_vec);
VLOG(3) << "push " << sparse_ids_vec->size() << " ids to " << key
<< "'s queue";
}
Expand All @@ -1051,7 +1051,10 @@ void GeoCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,

for (auto &iter : send_varname_to_ctx_) {
auto &ctx = iter.second;
if (!ctx.is_sparse) continue;
if (!ctx.is_sparse) {
parallel_task_nums_ += 1;
continue;
}
auto &varnames = ctx.origin_varnames;
PADDLE_ENFORCE_EQ(
varnames.size(), 1,
Expand All @@ -1060,12 +1063,11 @@ void GeoCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
for (auto &splited_var : ctx.splited_varnames) {
parallel_task_nums_ += 1;
sparse_id_queues_.insert(
std::pair<std::string, std::shared_ptr<BlockingQueue<
std::shared_ptr<std::vector<int64_t>>>>>(
std::pair<std::string, paddle::framework::Channel<
std::shared_ptr<std::vector<int64_t>>>>(
splited_var,
std::make_shared<
BlockingQueue<std::shared_ptr<std::vector<int64_t>>>>(
send_queue_size_)));
paddle::framework::MakeChannel<
std::shared_ptr<std::vector<int64_t>>>(send_queue_size_)));
}
}

Expand Down Expand Up @@ -1242,8 +1244,8 @@ std::vector<int64_t> GeoCommunicator::MergeSparseIds(
VLOG(3) << "Merge Number of " << send_varname << " = " << merge_num;
if (sparse_id_queues_.at(send_varname)->Size() > 0) {
wait_times = 0;
std::shared_ptr<std::vector<int64_t>> pop_ids =
sparse_id_queues_.at(send_varname)->Pop();
std::shared_ptr<std::vector<int64_t>> pop_ids = nullptr;
sparse_id_queues_.at(send_varname)->Get(pop_ids);
for (size_t j = 0; j < pop_ids->size(); j++) {
sparse_ids.insert(pop_ids->at(j));
}
Expand All @@ -1268,6 +1270,9 @@ void GeoCommunicator::SendSparse(const std::string &varname,
std::vector<int64_t> &sparse_ids, int table_id,
int ep_idx) {
platform::RecordEvent record_event("GeoCommunicator->SendSparse");
if (sparse_ids.size() == 0) {
return;
}
std::string param_name = SplitedGradToParam(varname);
VLOG(1) << "In GeoCommunicator::SendSparse(" << varname << " " << param_name
<< ", ids.size = " << sparse_ids.size() << ", table_id: " << table_id
Expand Down Expand Up @@ -1313,6 +1318,10 @@ void GeoCommunicator::SendSparse(const std::string &varname,
t_value + j * dims1,
t_old->data<float>() + sparse_ids[j] * dims1);
push_g_vec.push_back(t_value + j * dims1);

VLOG(5) << "DEBUG GeoCommunicator::SendSparse send sparse key "
<< sparse_ids[j] << " value[0] " << push_g_vec[j][0]
<< " value[-1] " << push_g_vec[j][dims1 - 1];
}

++_async_call_num;
Expand Down Expand Up @@ -1367,6 +1376,9 @@ void GeoCommunicator::RecvSparse(const std::string &varname, int table_id,
cpu_ctx);

for (auto j = 0; j < static_cast<int>(keys.size()); ++j) {
VLOG(5) << "DEBUG GeoCommunicator::RecvSparse recv sparse key" << keys[j]
<< "value[0] " << values[j * dims1] << " value[-1] "
<< values[j * dims1 + dims1 - 1];
float *latest_data = t_latest->data<float>() + keys[j] * dims1;
float *old_data = t_old->data<float>() + keys[j] * dims1;
// pserver - old => delta
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ limitations under the License. */

#include "gflags/gflags.h"
#include "paddle/fluid/distributed/ps/service/communicator/communicator_common.h"
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/framework/variable_helper.h"
Expand Down Expand Up @@ -626,9 +627,8 @@ class GeoCommunicator : public AsyncCommunicator {
// parameter on pserver
std::shared_ptr<Scope> pserver_scope_;

std::unordered_map<
std::string,
std::shared_ptr<BlockingQueue<std::shared_ptr<std::vector<int64_t>>>>>
std::unordered_map<std::string, paddle::framework::Channel<
std::shared_ptr<std::vector<int64_t>>>>
sparse_id_queues_;
};

Expand Down
11 changes: 11 additions & 0 deletions paddle/fluid/distributed/ps/service/ps_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,17 @@ class PSClient {
const uint64_t *keys, size_t num,
bool is_training) = 0;

virtual std::future<int32_t> pull_sparse_param(float **select_values,
size_t table_id,
const uint64_t *keys,
size_t num, bool is_training) {
VLOG(0) << "Did not implement";
std::promise<int32_t> promise;
std::future<int> fut = promise.get_future();
promise.set_value(-1);
return fut;
}

virtual ::std::future<int32_t> pull_sparse_ptr(char **select_values,
size_t table_id,
const uint64_t *keys,
Expand Down
5 changes: 4 additions & 1 deletion paddle/fluid/distributed/ps/table/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ cc_library(sparse_sgd_rule SRCS sparse_sgd_rule.cc DEPS ${TABLE_DEPS} ps_framewo
cc_library(ctr_accessor SRCS ctr_accessor.cc DEPS ${TABLE_DEPS} ps_framework_proto sparse_sgd_rule)
cc_library(memory_sparse_table SRCS memory_sparse_table.cc DEPS ps_framework_proto ${TABLE_DEPS} fs afs_wrapper ctr_accessor common_table)

cc_library(table SRCS table.cc DEPS memory_sparse_table common_table tensor_accessor tensor_table ps_framework_proto string_helper device_context gflags glog boost)
set_source_files_properties(memory_sparse_geo_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_library(memory_sparse_geo_table SRCS memory_sparse_geo_table.cc DEPS ps_framework_proto ${TABLE_DEPS} common_table)

cc_library(table SRCS table.cc DEPS memory_sparse_table memory_sparse_geo_table common_table tensor_accessor tensor_table ps_framework_proto string_helper device_context gflags glog boost)

target_link_libraries(table -fopenmp)
4 changes: 0 additions & 4 deletions paddle/fluid/distributed/ps/table/depends/geo_recorder.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,9 @@
#pragma once

#include <ThreadPool.h>
#include <functional>
#include <future> // NOLINT
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

namespace paddle {
Expand Down
Loading

0 comments on commit 9b3b53b

Please sign in to comment.