geo memory sparse table #39250

Merged · 3 commits · Jan 30, 2022
146 changes: 132 additions & 14 deletions paddle/fluid/distributed/ps/service/brpc_ps_client.cc
@@ -213,6 +213,7 @@ int32_t BrpcPsClient::initialize() {
auto &profiler = CostProfiler::instance();
profiler.register_profiler("pserver_client_pull_dense");
profiler.register_profiler("pserver_client_pull_sparse");
profiler.register_profiler("pserver_client_pull_sparse_param");
profiler.register_profiler("pserver_client_pull_sparse_local");
profiler.register_profiler("pserver_client_push_sparse");
profiler.register_profiler("pserver_client_push_sparse_parse");
@@ -543,6 +544,7 @@ std::future<int32_t> BrpcPsClient::pull_geo_param(size_t table_id,
return fut;
}

// for GEO
std::future<int32_t> BrpcPsClient::push_sparse_param(
size_t table_id, const uint64_t *keys, const float **update_values,
size_t num, void *done) {
@@ -558,18 +560,8 @@ std::future<int32_t> BrpcPsClient::push_sparse_param(
ids.resize(request_call_num);
value_ptrs.resize(request_call_num);

const auto &server_param = _config.server_param().downpour_server_param();
uint64_t shard_num = FLAGS_pserver_sparse_table_shard_num;
for (int i = 0; i < server_param.downpour_table_param_size(); ++i) {
const auto &table_param = server_param.downpour_table_param(i);
if (table_param.table_id() == table_id) {
shard_num = table_param.shard_num();
break;
}
}

for (size_t i = 0; i < num; ++i) {
size_t pserver_idx = get_sparse_shard(shard_num, request_call_num, keys[i]);
size_t pserver_idx = keys[i] % request_call_num;
ids[pserver_idx].push_back(keys[i]);
value_ptrs[pserver_idx].push_back(update_values[i]);
}
@@ -1003,6 +995,120 @@ std::future<int32_t> BrpcPsClient::pull_sparse(float **select_values,
return fut;
}

// for GEO
std::future<int32_t> BrpcPsClient::pull_sparse_param(float **select_values,
size_t table_id,
const uint64_t *keys,
size_t num,
bool is_training) {
auto timer = std::make_shared<CostTimer>("pserver_client_pull_sparse_param");
size_t request_call_num = _server_channels.size();

auto shard_sorted_kvs = std::make_shared<
std::vector<std::vector<std::pair<uint64_t, float *>>>>();
shard_sorted_kvs->resize(request_call_num);

for (size_t i = 0; i < num; ++i) {
size_t shard_id = keys[i] % request_call_num;
shard_sorted_kvs->at(shard_id).push_back({keys[i], select_values[i]});
}

auto *accessor = table_accessor(table_id);
size_t value_size = accessor->select_size();

DownpourBrpcClosure *closure = new DownpourBrpcClosure(
request_call_num, [shard_sorted_kvs, value_size](void *done) {
int ret = 0;
auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
for (size_t i = 0; i < shard_sorted_kvs->size(); ++i) {
if (closure->check_response(i, PS_PULL_SPARSE_TABLE) != 0) {
ret = -1;
break;
}

auto &request_kvs = shard_sorted_kvs->at(i);
auto &res_io_buffer = closure->cntl(i)->response_attachment();
butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
uint64_t last_key = UINT64_MAX;
float *last_value_data = NULL;

// can remove sort&unique
for (size_t kv_idx = 0; kv_idx < request_kvs.size(); ++kv_idx) {
auto *kv_pair = &(request_kvs[kv_idx]);
if (kv_pair->first == last_key) {
memcpy(reinterpret_cast<void *>(kv_pair->second),
reinterpret_cast<void *>(last_value_data), value_size);
} else {
last_key = kv_pair->first;
last_value_data = kv_pair->second;
if (value_size !=
io_buffer_itr.copy_and_forward(
reinterpret_cast<void *>(last_value_data), value_size)) {
LOG(WARNING) << "res data is lack or not in format";
ret = -1;
break;
}
}
}
}
closure->set_promise_value(ret);
});
closure->add_timer(timer);
auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise);
std::future<int> fut = promise->get_future();

for (size_t i = 0; i < request_call_num; ++i) {
auto &sorted_kvs = shard_sorted_kvs->at(i);
std::sort(sorted_kvs.begin(), sorted_kvs.end(),
[](const std::pair<uint64_t, float *> &k1,
const std::pair<uint64_t, float *> &k2) {
return k1.first < k2.first;
});

uint64_t last_key = UINT64_MAX;
uint32_t kv_request_count = 0;
size_t sorted_kv_size = sorted_kvs.size();
auto &request_buffer = closure->cntl(i)->request_attachment();

request_buffer.append(reinterpret_cast<void *>(&is_training), sizeof(bool));
std::vector<uint32_t> keys_counter;
keys_counter.reserve(sorted_kv_size);

for (size_t kv_idx = 0; kv_idx < sorted_kv_size; ++kv_idx) {
++kv_request_count;
uint32_t keys = 1;
last_key = sorted_kvs[kv_idx].first;
request_buffer.append(reinterpret_cast<void *>(&last_key),
sizeof(uint64_t));
while (kv_idx < sorted_kv_size - 1 &&
last_key == sorted_kvs[kv_idx + 1].first) {
++kv_idx;
++keys;
}
keys_counter.push_back(keys);
}

request_buffer.append(reinterpret_cast<void *>(keys_counter.data()),
sizeof(uint32_t) * keys_counter.size());

if (kv_request_count == 0) {
closure->Run();
} else {
closure->request(i)->set_cmd_id(PS_PULL_SPARSE_TABLE);
closure->request(i)->set_table_id(table_id);
closure->request(i)->set_client_id(_client_id);
closure->request(i)->add_params((char *)&kv_request_count, // NOLINT
sizeof(uint32_t));
PsService_Stub rpc_stub(get_cmd_channel(i));
closure->cntl(i)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(i), closure->request(i),
closure->response(i), closure);
}
}
return fut;
}

std::future<int32_t> BrpcPsClient::send_client2client_msg(
int msg_type, int to_client_id, const std::string &msg) {
auto promise = std::make_shared<std::promise<int32_t>>();
@@ -1067,12 +1173,14 @@ int32_t BrpcPsClient::recv_and_save_table(const uint64_t table_id,
std::string var_name = "";
int64_t var_num = 0;
int64_t var_shape = 0;
std::string table_class;
const auto &worker_param = _config.worker_param().downpour_worker_param();
for (size_t i = 0; i < worker_param.downpour_table_param_size(); ++i) {
if (worker_param.downpour_table_param(i).table_id() == table_id) {
var_name = worker_param.downpour_table_param(i).common().table_name();
var_num = worker_param.downpour_table_param(i).common().table_num();
var_shape = worker_param.downpour_table_param(i).common().table_dim();
table_class = worker_param.downpour_table_param(i).table_class();
break;
}
}
@@ -1094,9 +1202,19 @@ int32_t BrpcPsClient::recv_and_save_table(const uint64_t table_id,
save_vec.push_back(save_huge_vec.data() + i * var_shape);
}

auto status = pull_sparse(reinterpret_cast<float **>(save_vec.data()),
table_id, save_key.data(), save_key.size(), true);
status.wait();
VLOG(2) << "recv_and_save_table: table_class: " << table_class;
// TODO(zhaocaibei123): new GeoBrpcPSClient, move this to its
// recv_and_save_table
if (table_class == "MemorySparseGeoTable") {
auto status =
pull_sparse_param(reinterpret_cast<float **>(save_vec.data()), table_id,
save_key.data(), save_key.size(), true);
status.wait();
} else {
auto status = pull_sparse(reinterpret_cast<float **>(save_vec.data()),
table_id, save_key.data(), save_key.size(), true);
status.wait();
}

// create lod tensor
std::shared_ptr<framework::Scope> scope;
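The new BrpcPsClient::pull_sparse_param above routes each key to a pserver with key % request_call_num, sorts each shard's keys, and collapses duplicates into a per-key counter before packing the request attachment. Below is a minimal standalone sketch of that grouping step; ShardRequest and BuildShardRequests are illustrative names, not part of this PR.

// Minimal standalone sketch (not the real client code): bucket keys by
// key % shard_count, sort each bucket, then collapse runs of duplicate keys
// while recording how many caller slots each unique key has to fill.
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

struct ShardRequest {
  std::vector<uint64_t> unique_keys;   // deduplicated, sorted keys sent to one pserver
  std::vector<uint32_t> keys_counter;  // number of requests collapsed into each key
};

std::vector<ShardRequest> BuildShardRequests(const std::vector<uint64_t> &keys,
                                             size_t shard_count) {
  std::vector<std::vector<uint64_t>> buckets(shard_count);
  for (uint64_t key : keys) {
    buckets[key % shard_count].push_back(key);  // same modulo routing as the GEO client
  }

  std::vector<ShardRequest> requests(shard_count);
  for (size_t i = 0; i < shard_count; ++i) {
    std::sort(buckets[i].begin(), buckets[i].end());
    for (size_t j = 0; j < buckets[i].size(); ++j) {
      uint32_t count = 1;
      while (j + 1 < buckets[i].size() && buckets[i][j + 1] == buckets[i][j]) {
        ++j;  // duplicate key: fetch once, fan the value out on the response side
        ++count;
      }
      requests[i].unique_keys.push_back(buckets[i][j]);
      requests[i].keys_counter.push_back(count);
    }
  }
  return requests;
}

The same modulo routing is what push_sparse_param switches to in this PR, replacing the per-table shard_num lookup.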
4 changes: 4 additions & 0 deletions paddle/fluid/distributed/ps/service/brpc_ps_client.h
@@ -194,6 +194,10 @@ class BrpcPsClient : public PSClient {
size_t table_id,
const uint64_t *keys, size_t num,
bool is_training);
virtual std::future<int32_t> pull_sparse_param(float **select_values,
size_t table_id,
const uint64_t *keys,
size_t num, bool is_training);

virtual std::future<int32_t> print_table_stat(uint32_t table_id);

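For reference, a hedged call-site sketch of the client API declared above, mirroring how recv_and_save_table drives it: allocate one contiguous buffer, hand pull_sparse_param one float* per key, and block on the returned future. PullParamsBlocking and its parameters are illustrative, and the sketch assumes the Paddle header and namespace from this diff.

#include <cstdint>
#include <vector>

#include "paddle/fluid/distributed/ps/service/brpc_ps_client.h"

// Hypothetical helper (not part of the PR): pull the current values of `keys`
// from a GEO sparse table and block until the RPCs complete.
int32_t PullParamsBlocking(paddle::distributed::BrpcPsClient *client,
                           size_t table_id, const std::vector<uint64_t> &keys,
                           size_t value_dim, std::vector<float> *storage) {
  storage->resize(keys.size() * value_dim);  // one value_dim-sized slot per key
  std::vector<float *> value_ptrs;
  value_ptrs.reserve(keys.size());
  for (size_t i = 0; i < keys.size(); ++i) {
    value_ptrs.push_back(storage->data() + i * value_dim);
  }
  auto status = client->pull_sparse_param(value_ptrs.data(), table_id,
                                          keys.data(), keys.size(),
                                          /*is_training=*/true);
  return status.get();  // get() waits, then yields the RPC status (0 on success)
}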
32 changes: 22 additions & 10 deletions paddle/fluid/distributed/ps/service/communicator/communicator.cc
@@ -354,7 +354,7 @@ void Communicator::RpcRecvSparse(const std::string &varname, int table_id,

bool training = true;

auto status = _worker_ptr->pull_sparse(
auto status = _worker_ptr->pull_sparse_param(
(float **)push_g_vec.data(), table_id, // NOLINT
sparse_push_keys.data(), sparse_push_keys.size(), training);
status.wait();
@@ -1029,7 +1029,7 @@ void GeoCommunicator::Send(const std::vector<std::string> &var_names,
auto &sparse_ids_set = iter.second;
auto sparse_ids_vec = std::make_shared<std::vector<int64_t>>();
sparse_ids_vec->assign(sparse_ids_set.begin(), sparse_ids_set.end());
sparse_id_queues_.at(key)->Push(sparse_ids_vec);
sparse_id_queues_.at(key)->Put(sparse_ids_vec);
VLOG(3) << "push " << sparse_ids_vec->size() << " ids to " << key
<< "'s queue";
}
@@ -1051,7 +1051,10 @@ void GeoCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,

for (auto &iter : send_varname_to_ctx_) {
auto &ctx = iter.second;
if (!ctx.is_sparse) continue;
if (!ctx.is_sparse) {
parallel_task_nums_ += 1;
continue;
}
auto &varnames = ctx.origin_varnames;
PADDLE_ENFORCE_EQ(
varnames.size(), 1,
@@ -1060,12 +1063,11 @@ void GeoCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
for (auto &splited_var : ctx.splited_varnames) {
parallel_task_nums_ += 1;
sparse_id_queues_.insert(
std::pair<std::string, std::shared_ptr<BlockingQueue<
std::shared_ptr<std::vector<int64_t>>>>>(
std::pair<std::string, paddle::framework::Channel<
std::shared_ptr<std::vector<int64_t>>>>(
splited_var,
std::make_shared<
BlockingQueue<std::shared_ptr<std::vector<int64_t>>>>(
send_queue_size_)));
paddle::framework::MakeChannel<
std::shared_ptr<std::vector<int64_t>>>(send_queue_size_)));
}
}

@@ -1242,8 +1244,8 @@ std::vector<int64_t> GeoCommunicator::MergeSparseIds(
VLOG(3) << "Merge Number of " << send_varname << " = " << merge_num;
if (sparse_id_queues_.at(send_varname)->Size() > 0) {
wait_times = 0;
std::shared_ptr<std::vector<int64_t>> pop_ids =
sparse_id_queues_.at(send_varname)->Pop();
std::shared_ptr<std::vector<int64_t>> pop_ids = nullptr;
sparse_id_queues_.at(send_varname)->Get(pop_ids);
for (size_t j = 0; j < pop_ids->size(); j++) {
sparse_ids.insert(pop_ids->at(j));
}
@@ -1268,6 +1270,9 @@ void GeoCommunicator::SendSparse(const std::string &varname,
std::vector<int64_t> &sparse_ids, int table_id,
int ep_idx) {
platform::RecordEvent record_event("GeoCommunicator->SendSparse");
if (sparse_ids.size() == 0) {
return;
}
std::string param_name = SplitedGradToParam(varname);
VLOG(1) << "In GeoCommunicator::SendSparse(" << varname << " " << param_name
<< ", ids.size = " << sparse_ids.size() << ", table_id: " << table_id
@@ -1313,6 +1318,10 @@ void GeoCommunicator::SendSparse(const std::string &varname,
t_value + j * dims1,
t_old->data<float>() + sparse_ids[j] * dims1);
push_g_vec.push_back(t_value + j * dims1);

VLOG(5) << "DEBUG GeoCommunicator::SendSparse send sparse key "
<< sparse_ids[j] << " value[0] " << push_g_vec[j][0]
<< " value[-1] " << push_g_vec[j][dims1 - 1];
}

++_async_call_num;
@@ -1367,6 +1376,9 @@ void GeoCommunicator::RecvSparse(const std::string &varname, int table_id,
cpu_ctx);

for (auto j = 0; j < static_cast<int>(keys.size()); ++j) {
VLOG(5) << "DEBUG GeoCommunicator::RecvSparse recv sparse key" << keys[j]
<< "value[0] " << values[j * dims1] << " value[-1] "
<< values[j * dims1 + dims1 - 1];
float *latest_data = t_latest->data<float>() + keys[j] * dims1;
float *old_data = t_old->data<float>() + keys[j] * dims1;
// pserver - old => delta
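GeoCommunicator now buffers sparse ids per split variable in a paddle::framework::Channel (Put in Send, Get in MergeSparseIds) instead of a BlockingQueue, and the merge step unions the queued batches into a set so each id is pushed at most once per send round. A rough standalone sketch of that merge, with a std::deque standing in for the channel and MergeSparseIdsSketch as an illustrative name:

#include <cstdint>
#include <deque>
#include <memory>
#include <set>
#include <vector>

// Each Put() in Send() enqueues one batch of ids; the merge step unions them.
std::vector<int64_t> MergeSparseIdsSketch(
    std::deque<std::shared_ptr<std::vector<int64_t>>> *queue) {
  std::set<int64_t> merged;  // dedup across batches, like the sparse_ids set
  while (!queue->empty()) {
    auto batch = queue->front();  // the real code calls Channel::Get(pop_ids)
    queue->pop_front();
    merged.insert(batch->begin(), batch->end());
  }
  return std::vector<int64_t>(merged.begin(), merged.end());
}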
@@ -30,6 +30,7 @@ limitations under the License. */

#include "gflags/gflags.h"
#include "paddle/fluid/distributed/ps/service/communicator/communicator_common.h"
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/framework/variable_helper.h"
@@ -626,9 +627,8 @@ class GeoCommunicator : public AsyncCommunicator {
// parameter on pserver
std::shared_ptr<Scope> pserver_scope_;

std::unordered_map<
std::string,
std::shared_ptr<BlockingQueue<std::shared_ptr<std::vector<int64_t>>>>>
std::unordered_map<std::string, paddle::framework::Channel<
std::shared_ptr<std::vector<int64_t>>>>
sparse_id_queues_;
};

11 changes: 11 additions & 0 deletions paddle/fluid/distributed/ps/service/ps_client.h
@@ -128,6 +128,17 @@ class PSClient {
const uint64_t *keys, size_t num,
bool is_training) = 0;

virtual std::future<int32_t> pull_sparse_param(float **select_values,
size_t table_id,
const uint64_t *keys,
size_t num, bool is_training) {
VLOG(0) << "Did not implement";
std::promise<int32_t> promise;
std::future<int> fut = promise.get_future();
promise.set_value(-1);
return fut;
}

virtual ::std::future<int32_t> pull_sparse_ptr(char **select_values,
size_t table_id,
const uint64_t *keys,
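The PSClient::pull_sparse_param default above returns a future that is already resolved to -1, so clients that do not override it fail fast when callers wait on the result. A minimal sketch of that ready-future pattern in isolation; NotImplementedFuture is an illustrative name:

#include <cstdint>
#include <future>
#include <iostream>

std::future<int32_t> NotImplementedFuture() {
  std::promise<int32_t> promise;
  std::future<int32_t> fut = promise.get_future();
  promise.set_value(-1);  // resolve before the promise goes out of scope
  return fut;             // the shared state stays valid after the promise dies
}

int main() {
  std::cout << "status: " << NotImplementedFuture().get() << std::endl;  // -1
  return 0;
}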
5 changes: 4 additions & 1 deletion paddle/fluid/distributed/ps/table/CMakeLists.txt
@@ -47,6 +47,9 @@ cc_library(sparse_sgd_rule SRCS sparse_sgd_rule.cc DEPS ${TABLE_DEPS} ps_framewo
cc_library(ctr_accessor SRCS ctr_accessor.cc DEPS ${TABLE_DEPS} ps_framework_proto sparse_sgd_rule)
cc_library(memory_sparse_table SRCS memory_sparse_table.cc DEPS ps_framework_proto ${TABLE_DEPS} fs afs_wrapper ctr_accessor common_table)

cc_library(table SRCS table.cc DEPS memory_sparse_table common_table tensor_accessor tensor_table ps_framework_proto string_helper device_context gflags glog boost)
set_source_files_properties(memory_sparse_geo_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_library(memory_sparse_geo_table SRCS memory_sparse_geo_table.cc DEPS ps_framework_proto ${TABLE_DEPS} common_table)

cc_library(table SRCS table.cc DEPS memory_sparse_table memory_sparse_geo_table common_table tensor_accessor tensor_table ps_framework_proto string_helper device_context gflags glog boost)

target_link_libraries(table -fopenmp)
4 changes: 0 additions & 4 deletions paddle/fluid/distributed/ps/table/depends/geo_recorder.h
@@ -15,13 +15,9 @@
#pragma once

#include <ThreadPool.h>
#include <functional>
#include <future> // NOLINT
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

namespace paddle {