Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mod base #40702

Merged
merged 1 commit into from
Mar 21, 2022
Merged

mod base #40702

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions paddle/fluid/distributed/ps/service/brpc_ps_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,16 @@ std::future<int32_t> BrpcPsClient::load(uint32_t table_id,
return send_cmd(table_id, PS_LOAD_ONE_TABLE, {epoch, mode});
}

std::future<int32_t> BrpcPsClient::Load(const LoadSaveContext &load_context) {
if (load_context.table_id < 0) {
return send_cmd(-1, PS_LOAD_ALL_TABLE,
{load_context.epoch, load_context.mode});
} else {
return send_cmd(load_context.table_id, PS_LOAD_ONE_TABLE,
{load_context.epoch, load_context.mode});
}
}

std::future<int32_t> BrpcPsClient::save(const std::string &epoch,
const std::string &mode) {
VLOG(1) << "BrpcPsClient::save path " << epoch;
Expand All @@ -427,6 +437,19 @@ std::future<int32_t> BrpcPsClient::save(uint32_t table_id,
return send_save_cmd(table_id, PS_SAVE_ONE_TABLE, {epoch, mode});
}

std::future<int32_t> BrpcPsClient::Save(const LoadSaveContext &save_context) {
if (save_context.table_id < 0) {
VLOG(1) << "BrpcPsClient::save path " << save_context.epoch;
return send_save_cmd(-1, PS_SAVE_ALL_TABLE,
{save_context.epoch, save_context.mode});
} else {
VLOG(1) << "BrpcPsClient::save one table path " << save_context.epoch
<< " table_id " << save_context.table_id;
return send_save_cmd(save_context.table_id, PS_SAVE_ONE_TABLE,
{save_context.epoch, save_context.mode});
}
}

std::future<int32_t> BrpcPsClient::clear() {
return send_cmd(-1, PS_CLEAR_ALL_TABLE, {});
}
Expand Down Expand Up @@ -505,6 +528,44 @@ std::future<int32_t> BrpcPsClient::barrier(size_t table_id,
return send_cmd(table_id, PS_BARRIER, {std::to_string(barrier_type)});
}

std::future<int32_t> BrpcPsClient::Pull(RequestContext &pull_context) {
if (pull_context.value_type == Dense) { // pull dense
Region *dense_region =
reinterpret_cast<Region *>(pull_context.dense_values);
pull_dense(dense_region, pull_context.num, pull_context.table);
} else { // pull sparse
uint64_t *keys = reinterpret_cast<uint64_t *>(pull_context.keys);
float **select_values =
reinterpret_cast<float **>(pull_context.sparse_values);
size_t table_id = pull_context.table;
size_t num = pull_context.num;
bool is_training = pull_context.is_training;
if (pull_context.training_mode == Geo) { // for geo
pull_sparse_param(select_values, table_id, keys, num, is_training);
} else if (pull_context.training_mode == Async) { // for async
pull_sparse(select_values, table_id, keys, num, is_training);
}
}
}

std::future<int32_t> BrpcPsClient::Push(RequestContext &push_context) {
if (push_context.value_type == Dense) { // push dense
const Region *dense_region = push_context.push_context.push_dense_values;
push_dense(dense_region, push_context.num, push_context.table);
} else { // push sparse
size_t table_id = push_context.table;
size_t num = push_context.num;
bool is_training = push_context.is_training;
if (push_context.training_mode == Geo) { // for geo
// TODO(zhaocaibei)
} else if (push_context.training_mode == Async) { // for async
const uint64_t *keys = push_context.push_context.keys;
const float **update_values = push_context.push_context.push_values;
push_sparse(table_id, keys, update_values, num);
}
}
}

std::future<int32_t> BrpcPsClient::pull_geo_param(size_t table_id,
std::vector<float> *values,
std::vector<uint64_t> *keys,
Expand Down
9 changes: 9 additions & 0 deletions paddle/fluid/distributed/ps/service/brpc_ps_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,12 +163,17 @@ class BrpcPsClient : public PSClient {
std::future<int32_t> load(uint32_t table_id, const std::string &epoch,
const std::string &mode) override;

std::future<int32_t> Load(const LoadSaveContext &load_context) override;

std::future<int32_t> save(const std::string &epoch,
const std::string &mode) override;

std::future<int32_t> save(uint32_t table_id, const std::string &epoch,
const std::string &mode) override;

virtual std::future<int32_t> Save(
const LoadSaveContext &save_context) override;

std::future<int32_t> clear() override;

std::future<int32_t> clear(uint32_t table_id) override;
Expand Down Expand Up @@ -199,6 +204,10 @@ class BrpcPsClient : public PSClient {
const uint64_t *keys,
size_t num, bool is_training);

virtual std::future<int32_t> Pull(RequestContext &pull_context) override;

virtual std::future<int32_t> Push(RequestContext &push_context) override;

virtual std::future<int32_t> print_table_stat(uint32_t table_id);

virtual std::future<int32_t> barrier(size_t table_id, uint32_t barrier_type);
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/distributed/ps/service/brpc_ps_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class BrpcPsServer : public PSServer {
_server.Join();
return 0;
}
virtual int32_t port();
int32_t port();

private:
virtual int32_t initialize();
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/distributed/ps/service/graph_brpc_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class GraphBrpcServer : public PSServer {
_server.Join();
return 0;
}
virtual int32_t port();
int32_t port();

std::condition_variable *export_cv() { return &cv_; }

Expand Down
46 changes: 46 additions & 0 deletions paddle/fluid/distributed/ps/service/ps_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
#include "paddle/fluid/distributed/ps/table/accessor.h"
#include "paddle/fluid/distributed/ps/table/graph/graph_node.h"
#include "paddle/fluid/distributed/ps/table/table.h"
#include "paddle/fluid/platform/timer.h"

namespace paddle {
Expand Down Expand Up @@ -59,6 +60,41 @@ class PSClientClosure : public google::protobuf::Closure {
std::vector<std::shared_ptr<std::promise<int32_t>>> _promises;
};

struct LoadSaveContext {
int table_id;
std::string epoch;
std::string mode;
};

enum TrainingMode { Async = 0, Sync = 1, Geo = 3 };

enum TrainingPhase { Init = 0, Train = 1, Save = 2 };

// enum ValueType {
// Sparse = 0,
// Dense = 1
// };

struct PushContext {
const uint64_t *keys;
const float **push_values;
const Region *push_dense_values;
};

struct RequestContext {
int table;
TrainingMode training_mode; // 1 for async, 2 for geo, 3 for sync
TrainingPhase training_phase; // 1 for init, 2 for train
ValueType value_type; // 1 for sparse, 2 for dense
void *keys;
void **sparse_values; // for sparse values
Region *dense_values; // for dense values
PushContext push_context;
size_t num;
bool is_training;
void *callback;
};

class PSClient {
public:
PSClient() {}
Expand Down Expand Up @@ -86,13 +122,18 @@ class PSClient {
// 指定table数据load
virtual std::future<int32_t> load(uint32_t table_id, const std::string &epoch,
const std::string &mode) = 0;
// context配置load选项
virtual std::future<int32_t> Load(const LoadSaveContext &load_context) = 0;

// 全量table数据save value_accessor根据mode,可能有不同的save条件
virtual std::future<int32_t> save(const std::string &epoch,
const std::string &mode) = 0;
// 指定table数据save value_accessor根据mode,可能有不同的save条件
virtual std::future<int32_t> save(uint32_t table_id, const std::string &epoch,
const std::string &mode) = 0;

virtual std::future<int32_t> Save(const LoadSaveContext &save_context) = 0;

// 清空table数据
virtual std::future<int32_t> clear() = 0;
virtual std::future<int32_t> clear(uint32_t table_id) = 0;
Expand All @@ -107,6 +148,8 @@ class PSClient {
virtual std::future<int32_t> pull_dense(Region *regions, size_t region_num,
size_t table_id) = 0; // 保留

virtual std::future<int32_t> Push(RequestContext &push_context) = 0;

// firstly push dense param for parameter server
// this is neccessary because dense weight initialized in trainer on cold
// start
Expand All @@ -117,6 +160,9 @@ class PSClient {
virtual std::future<int32_t> push_dense(const Region *regions,
size_t region_num,
size_t table_id) = 0;

virtual std::future<int32_t> Pull(RequestContext &pull_context) = 0;

// 使用keys进行pull请求,结果填充values
// keys和values的个数均为num个,每个value占用select_size空间
// future结束前keys和values缓冲区不能再次使用
Expand Down
73 changes: 73 additions & 0 deletions paddle/fluid/distributed/ps/service/ps_local_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,19 @@ ::std::future<int32_t> PsLocalClient::load(uint32_t table_id,
return done();
}

std::future<int32_t> PsLocalClient::Load(const LoadSaveContext& load_context) {
if (load_context.table_id < 0) {
for (auto& it : _table_map) {
load(it.first, load_context.epoch, load_context.mode);
}
return done();
} else {
auto* table_ptr = table(load_context.table_id);
table_ptr->load(load_context.epoch, load_context.mode);
return done();
}
}

::std::future<int32_t> PsLocalClient::save(const std::string& epoch,
const std::string& mode) {
// TODO
Expand All @@ -74,6 +87,21 @@ ::std::future<int32_t> PsLocalClient::save(uint32_t table_id,
return done();
}

::std::future<int32_t> PsLocalClient::Save(
const LoadSaveContext& save_context) {
if (save_context.table_id < 0) {
for (auto& it : _table_map) {
save(it.first, save_context.epoch, save_context.mode);
}
return done();
} else {
auto* table_ptr = table(save_context.table_id);
table_ptr->flush();
table_ptr->save(save_context.epoch, save_context.mode);
return done();
}
}

::std::future<int32_t> PsLocalClient::clear() {
// TODO
return done();
Expand All @@ -93,6 +121,51 @@ ::std::future<int32_t> PsLocalClient::stop_server() {
return done();
}

::std::future<int32_t> PsLocalClient::Pull(RequestContext& pull_context) {
if (pull_context.value_type == Dense) { // pull dense
Region* dense_region = reinterpret_cast<Region*>(pull_context.dense_values);
pull_dense(dense_region, pull_context.num, pull_context.table);
} else { // pull sparse
uint64_t* keys = reinterpret_cast<uint64_t*>(pull_context.keys);
char** select_values = reinterpret_cast<char**>(pull_context.sparse_values);
size_t table_id = pull_context.table;
size_t num = pull_context.num;
pull_sparse_ptr(select_values, table_id, keys, num);
}
}

::std::future<int32_t> PsLocalClient::Push(RequestContext& push_context) {
if (push_context.value_type == Dense) { // push dense
if (push_context.training_phase == Init) {
const Region* regions = push_context.push_context.push_dense_values;
size_t region_num = push_context.num;
push_dense_param(regions, region_num, push_context.table);
} else {
if (push_context.training_mode == Geo) { // geo
float* total_send_data =
reinterpret_cast<float*>(push_context.dense_values);
size_t total_send_data_size = push_context.num;
push_dense_raw_gradient(push_context.table, total_send_data,
total_send_data_size, push_context.callback);
} else { // async and sync
const Region* regions = push_context.push_context.push_dense_values;
size_t region_num = push_context.num;
push_dense(regions, region_num, push_context.table);
}
}
} else { // push sparse
if (push_context.training_mode == Async) {
const uint64_t* keys = push_context.push_context.keys;
const float** update_values = push_context.push_context.push_values;
size_t table_id = push_context.table;
size_t num = push_context.num;
push_sparse(table_id, keys, update_values, num);
} else {
// TODO
}
}
}

::std::future<int32_t> PsLocalClient::pull_dense(Region* regions,
size_t region_num,
size_t table_id) {
Expand Down
8 changes: 8 additions & 0 deletions paddle/fluid/distributed/ps/service/ps_local_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,16 @@ class PsLocalClient : public PSClient {
virtual ::std::future<int32_t> load(uint32_t table_id,
const std::string& epoch,
const std::string& mode) override;
virtual std::future<int32_t> Load(
const LoadSaveContext& load_context) override;

virtual ::std::future<int32_t> save(const std::string& epoch,
const std::string& mode) override;
virtual ::std::future<int32_t> save(uint32_t table_id,
const std::string& epoch,
const std::string& mode) override;
virtual std::future<int32_t> Save(
const LoadSaveContext& save_context) override;

virtual ::std::future<int32_t> clear() override;
virtual ::std::future<int32_t> clear(uint32_t table_id) override;
Expand All @@ -55,6 +59,10 @@ class PsLocalClient : public PSClient {
virtual ::std::future<int32_t> pull_dense(Region* regions, size_t region_num,
size_t table_id);

virtual ::std::future<int32_t> Pull(RequestContext& pull_context) override;

virtual ::std::future<int32_t> Push(RequestContext& push_context) override;

virtual ::std::future<int32_t> push_dense(const Region* regions,
size_t region_num, size_t table_id);

Expand Down
1 change: 0 additions & 1 deletion paddle/fluid/distributed/ps/service/ps_local_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ class PsLocalServer : public PSServer {
virtual uint64_t start() { return 0; }
virtual uint64_t start(const std::string &ip, uint32_t port) { return 0; }
virtual int32_t stop() { return 0; }
virtual int32_t port() { return 0; }
virtual int32_t configure(
const PSParameter &config, PSEnvironment &env, size_t server_rank,
const std::vector<framework::ProgramDesc> &server_sub_program = {}) {
Expand Down
2 changes: 0 additions & 2 deletions paddle/fluid/distributed/ps/service/server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,6 @@ int32_t PSServer::configure(
_config = config.server_param();
_rank = server_rank;
_environment = &env;
_shuffled_ins =
paddle::framework::MakeChannel<std::pair<uint64_t, std::string>>();
size_t shard_num = env.get_ps_servers().size();

const auto &downpour_param = _config.downpour_server_param();
Expand Down
15 changes: 0 additions & 15 deletions paddle/fluid/distributed/ps/service/server.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,6 @@ class PSServer {
const PSParameter &config, PSEnvironment &env, size_t server_rank,
const std::vector<framework::ProgramDesc> &server_sub_program = {});

// return server_ip
virtual std::string ip() { return butil::my_ip_cstr(); }
// return server_port
virtual int32_t port() = 0;

virtual uint64_t start(const std::string &ip, uint32_t port) = 0;
virtual int32_t stop() = 0;

Expand All @@ -94,15 +89,6 @@ class PSServer {
return &_table_map;
}

typedef std::function<int32_t(int, int, const std::string &)> MsgHandlerFunc;
virtual int registe_pserver2pserver_msg_handler(int msg_type,
MsgHandlerFunc handler) {
_msg_handler_map[msg_type] = handler;
return 0;
}

paddle::framework::Channel<std::pair<uint64_t, std::string>> _shuffled_ins;

protected:
virtual int32_t initialize() = 0;

Expand All @@ -111,7 +97,6 @@ class PSServer {
ServerParameter _config;
PSEnvironment *_environment;
std::unordered_map<uint32_t, std::shared_ptr<Table>> _table_map;
std::unordered_map<int32_t, MsgHandlerFunc> _msg_handler_map;

protected:
std::shared_ptr<framework::Scope> scope_;
Expand Down
Loading