Skip to content

Commit

Permalink
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Browse files Browse the repository at this point in the history
… backend_to_place
  • Loading branch information
zyfncg committed Mar 21, 2022
2 parents 9e7b8c0 + 56c43cc commit 5f6e65c
Show file tree
Hide file tree
Showing 233 changed files with 10,628 additions and 5,109 deletions.
61 changes: 61 additions & 0 deletions paddle/fluid/distributed/ps/service/brpc_ps_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,16 @@ std::future<int32_t> BrpcPsClient::load(uint32_t table_id,
return send_cmd(table_id, PS_LOAD_ONE_TABLE, {epoch, mode});
}

std::future<int32_t> BrpcPsClient::Load(const LoadSaveContext &load_context) {
if (load_context.table_id < 0) {
return send_cmd(-1, PS_LOAD_ALL_TABLE,
{load_context.epoch, load_context.mode});
} else {
return send_cmd(load_context.table_id, PS_LOAD_ONE_TABLE,
{load_context.epoch, load_context.mode});
}
}

std::future<int32_t> BrpcPsClient::save(const std::string &epoch,
const std::string &mode) {
VLOG(1) << "BrpcPsClient::save path " << epoch;
Expand All @@ -427,6 +437,19 @@ std::future<int32_t> BrpcPsClient::save(uint32_t table_id,
return send_save_cmd(table_id, PS_SAVE_ONE_TABLE, {epoch, mode});
}

std::future<int32_t> BrpcPsClient::Save(const LoadSaveContext &save_context) {
if (save_context.table_id < 0) {
VLOG(1) << "BrpcPsClient::save path " << save_context.epoch;
return send_save_cmd(-1, PS_SAVE_ALL_TABLE,
{save_context.epoch, save_context.mode});
} else {
VLOG(1) << "BrpcPsClient::save one table path " << save_context.epoch
<< " table_id " << save_context.table_id;
return send_save_cmd(save_context.table_id, PS_SAVE_ONE_TABLE,
{save_context.epoch, save_context.mode});
}
}

std::future<int32_t> BrpcPsClient::clear() {
return send_cmd(-1, PS_CLEAR_ALL_TABLE, {});
}
Expand Down Expand Up @@ -505,6 +528,44 @@ std::future<int32_t> BrpcPsClient::barrier(size_t table_id,
return send_cmd(table_id, PS_BARRIER, {std::to_string(barrier_type)});
}

std::future<int32_t> BrpcPsClient::Pull(RequestContext &pull_context) {
if (pull_context.value_type == Dense) { // pull dense
Region *dense_region =
reinterpret_cast<Region *>(pull_context.dense_values);
pull_dense(dense_region, pull_context.num, pull_context.table);
} else { // pull sparse
uint64_t *keys = reinterpret_cast<uint64_t *>(pull_context.keys);
float **select_values =
reinterpret_cast<float **>(pull_context.sparse_values);
size_t table_id = pull_context.table;
size_t num = pull_context.num;
bool is_training = pull_context.is_training;
if (pull_context.training_mode == Geo) { // for geo
pull_sparse_param(select_values, table_id, keys, num, is_training);
} else if (pull_context.training_mode == Async) { // for async
pull_sparse(select_values, table_id, keys, num, is_training);
}
}
}

std::future<int32_t> BrpcPsClient::Push(RequestContext &push_context) {
if (push_context.value_type == Dense) { // push dense
const Region *dense_region = push_context.push_context.push_dense_values;
push_dense(dense_region, push_context.num, push_context.table);
} else { // push sparse
size_t table_id = push_context.table;
size_t num = push_context.num;
bool is_training = push_context.is_training;
if (push_context.training_mode == Geo) { // for geo
// TODO(zhaocaibei)
} else if (push_context.training_mode == Async) { // for async
const uint64_t *keys = push_context.push_context.keys;
const float **update_values = push_context.push_context.push_values;
push_sparse(table_id, keys, update_values, num);
}
}
}

std::future<int32_t> BrpcPsClient::pull_geo_param(size_t table_id,
std::vector<float> *values,
std::vector<uint64_t> *keys,
Expand Down
9 changes: 9 additions & 0 deletions paddle/fluid/distributed/ps/service/brpc_ps_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,12 +163,17 @@ class BrpcPsClient : public PSClient {
std::future<int32_t> load(uint32_t table_id, const std::string &epoch,
const std::string &mode) override;

std::future<int32_t> Load(const LoadSaveContext &load_context) override;

std::future<int32_t> save(const std::string &epoch,
const std::string &mode) override;

std::future<int32_t> save(uint32_t table_id, const std::string &epoch,
const std::string &mode) override;

virtual std::future<int32_t> Save(
const LoadSaveContext &save_context) override;

std::future<int32_t> clear() override;

std::future<int32_t> clear(uint32_t table_id) override;
Expand Down Expand Up @@ -199,6 +204,10 @@ class BrpcPsClient : public PSClient {
const uint64_t *keys,
size_t num, bool is_training);

virtual std::future<int32_t> Pull(RequestContext &pull_context) override;

virtual std::future<int32_t> Push(RequestContext &push_context) override;

virtual std::future<int32_t> print_table_stat(uint32_t table_id);

virtual std::future<int32_t> barrier(size_t table_id, uint32_t barrier_type);
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/distributed/ps/service/brpc_ps_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class BrpcPsServer : public PSServer {
_server.Join();
return 0;
}
virtual int32_t port();
int32_t port();

private:
virtual int32_t initialize();
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/distributed/ps/service/graph_brpc_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class GraphBrpcServer : public PSServer {
_server.Join();
return 0;
}
virtual int32_t port();
int32_t port();

std::condition_variable *export_cv() { return &cv_; }

Expand Down
46 changes: 46 additions & 0 deletions paddle/fluid/distributed/ps/service/ps_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
#include "paddle/fluid/distributed/ps/table/accessor.h"
#include "paddle/fluid/distributed/ps/table/graph/graph_node.h"
#include "paddle/fluid/distributed/ps/table/table.h"
#include "paddle/fluid/platform/timer.h"

namespace paddle {
Expand Down Expand Up @@ -59,6 +60,41 @@ class PSClientClosure : public google::protobuf::Closure {
std::vector<std::shared_ptr<std::promise<int32_t>>> _promises;
};

struct LoadSaveContext {
int table_id;
std::string epoch;
std::string mode;
};

enum TrainingMode { Async = 0, Sync = 1, Geo = 3 };

enum TrainingPhase { Init = 0, Train = 1, Save = 2 };

// enum ValueType {
// Sparse = 0,
// Dense = 1
// };

struct PushContext {
const uint64_t *keys;
const float **push_values;
const Region *push_dense_values;
};

struct RequestContext {
int table;
TrainingMode training_mode; // 1 for async, 2 for geo, 3 for sync
TrainingPhase training_phase; // 1 for init, 2 for train
ValueType value_type; // 1 for sparse, 2 for dense
void *keys;
void **sparse_values; // for sparse values
Region *dense_values; // for dense values
PushContext push_context;
size_t num;
bool is_training;
void *callback;
};

class PSClient {
public:
PSClient() {}
Expand Down Expand Up @@ -86,13 +122,18 @@ class PSClient {
// 指定table数据load
virtual std::future<int32_t> load(uint32_t table_id, const std::string &epoch,
const std::string &mode) = 0;
// context配置load选项
virtual std::future<int32_t> Load(const LoadSaveContext &load_context) = 0;

// 全量table数据save value_accessor根据mode,可能有不同的save条件
virtual std::future<int32_t> save(const std::string &epoch,
const std::string &mode) = 0;
// 指定table数据save value_accessor根据mode,可能有不同的save条件
virtual std::future<int32_t> save(uint32_t table_id, const std::string &epoch,
const std::string &mode) = 0;

virtual std::future<int32_t> Save(const LoadSaveContext &save_context) = 0;

// 清空table数据
virtual std::future<int32_t> clear() = 0;
virtual std::future<int32_t> clear(uint32_t table_id) = 0;
Expand All @@ -107,6 +148,8 @@ class PSClient {
virtual std::future<int32_t> pull_dense(Region *regions, size_t region_num,
size_t table_id) = 0; // 保留

virtual std::future<int32_t> Push(RequestContext &push_context) = 0;

// firstly push dense param for parameter server
// this is neccessary because dense weight initialized in trainer on cold
// start
Expand All @@ -117,6 +160,9 @@ class PSClient {
virtual std::future<int32_t> push_dense(const Region *regions,
size_t region_num,
size_t table_id) = 0;

virtual std::future<int32_t> Pull(RequestContext &pull_context) = 0;

// 使用keys进行pull请求,结果填充values
// keys和values的个数均为num个,每个value占用select_size空间
// future结束前keys和values缓冲区不能再次使用
Expand Down
73 changes: 73 additions & 0 deletions paddle/fluid/distributed/ps/service/ps_local_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,19 @@ ::std::future<int32_t> PsLocalClient::load(uint32_t table_id,
return done();
}

std::future<int32_t> PsLocalClient::Load(const LoadSaveContext& load_context) {
if (load_context.table_id < 0) {
for (auto& it : _table_map) {
load(it.first, load_context.epoch, load_context.mode);
}
return done();
} else {
auto* table_ptr = table(load_context.table_id);
table_ptr->load(load_context.epoch, load_context.mode);
return done();
}
}

::std::future<int32_t> PsLocalClient::save(const std::string& epoch,
const std::string& mode) {
// TODO
Expand All @@ -74,6 +87,21 @@ ::std::future<int32_t> PsLocalClient::save(uint32_t table_id,
return done();
}

::std::future<int32_t> PsLocalClient::Save(
const LoadSaveContext& save_context) {
if (save_context.table_id < 0) {
for (auto& it : _table_map) {
save(it.first, save_context.epoch, save_context.mode);
}
return done();
} else {
auto* table_ptr = table(save_context.table_id);
table_ptr->flush();
table_ptr->save(save_context.epoch, save_context.mode);
return done();
}
}

::std::future<int32_t> PsLocalClient::clear() {
// TODO
return done();
Expand All @@ -93,6 +121,51 @@ ::std::future<int32_t> PsLocalClient::stop_server() {
return done();
}

::std::future<int32_t> PsLocalClient::Pull(RequestContext& pull_context) {
if (pull_context.value_type == Dense) { // pull dense
Region* dense_region = reinterpret_cast<Region*>(pull_context.dense_values);
pull_dense(dense_region, pull_context.num, pull_context.table);
} else { // pull sparse
uint64_t* keys = reinterpret_cast<uint64_t*>(pull_context.keys);
char** select_values = reinterpret_cast<char**>(pull_context.sparse_values);
size_t table_id = pull_context.table;
size_t num = pull_context.num;
pull_sparse_ptr(select_values, table_id, keys, num);
}
}

::std::future<int32_t> PsLocalClient::Push(RequestContext& push_context) {
if (push_context.value_type == Dense) { // push dense
if (push_context.training_phase == Init) {
const Region* regions = push_context.push_context.push_dense_values;
size_t region_num = push_context.num;
push_dense_param(regions, region_num, push_context.table);
} else {
if (push_context.training_mode == Geo) { // geo
float* total_send_data =
reinterpret_cast<float*>(push_context.dense_values);
size_t total_send_data_size = push_context.num;
push_dense_raw_gradient(push_context.table, total_send_data,
total_send_data_size, push_context.callback);
} else { // async and sync
const Region* regions = push_context.push_context.push_dense_values;
size_t region_num = push_context.num;
push_dense(regions, region_num, push_context.table);
}
}
} else { // push sparse
if (push_context.training_mode == Async) {
const uint64_t* keys = push_context.push_context.keys;
const float** update_values = push_context.push_context.push_values;
size_t table_id = push_context.table;
size_t num = push_context.num;
push_sparse(table_id, keys, update_values, num);
} else {
// TODO
}
}
}

::std::future<int32_t> PsLocalClient::pull_dense(Region* regions,
size_t region_num,
size_t table_id) {
Expand Down
8 changes: 8 additions & 0 deletions paddle/fluid/distributed/ps/service/ps_local_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,16 @@ class PsLocalClient : public PSClient {
virtual ::std::future<int32_t> load(uint32_t table_id,
const std::string& epoch,
const std::string& mode) override;
virtual std::future<int32_t> Load(
const LoadSaveContext& load_context) override;

virtual ::std::future<int32_t> save(const std::string& epoch,
const std::string& mode) override;
virtual ::std::future<int32_t> save(uint32_t table_id,
const std::string& epoch,
const std::string& mode) override;
virtual std::future<int32_t> Save(
const LoadSaveContext& save_context) override;

virtual ::std::future<int32_t> clear() override;
virtual ::std::future<int32_t> clear(uint32_t table_id) override;
Expand All @@ -55,6 +59,10 @@ class PsLocalClient : public PSClient {
virtual ::std::future<int32_t> pull_dense(Region* regions, size_t region_num,
size_t table_id);

virtual ::std::future<int32_t> Pull(RequestContext& pull_context) override;

virtual ::std::future<int32_t> Push(RequestContext& push_context) override;

virtual ::std::future<int32_t> push_dense(const Region* regions,
size_t region_num, size_t table_id);

Expand Down
1 change: 0 additions & 1 deletion paddle/fluid/distributed/ps/service/ps_local_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ class PsLocalServer : public PSServer {
virtual uint64_t start() { return 0; }
virtual uint64_t start(const std::string &ip, uint32_t port) { return 0; }
virtual int32_t stop() { return 0; }
virtual int32_t port() { return 0; }
virtual int32_t configure(
const PSParameter &config, PSEnvironment &env, size_t server_rank,
const std::vector<framework::ProgramDesc> &server_sub_program = {}) {
Expand Down
2 changes: 0 additions & 2 deletions paddle/fluid/distributed/ps/service/server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,6 @@ int32_t PSServer::configure(
_config = config.server_param();
_rank = server_rank;
_environment = &env;
_shuffled_ins =
paddle::framework::MakeChannel<std::pair<uint64_t, std::string>>();
size_t shard_num = env.get_ps_servers().size();

const auto &downpour_param = _config.downpour_server_param();
Expand Down
15 changes: 0 additions & 15 deletions paddle/fluid/distributed/ps/service/server.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,6 @@ class PSServer {
const PSParameter &config, PSEnvironment &env, size_t server_rank,
const std::vector<framework::ProgramDesc> &server_sub_program = {});

// return server_ip
virtual std::string ip() { return butil::my_ip_cstr(); }
// return server_port
virtual int32_t port() = 0;

virtual uint64_t start(const std::string &ip, uint32_t port) = 0;
virtual int32_t stop() = 0;

Expand All @@ -94,15 +89,6 @@ class PSServer {
return &_table_map;
}

typedef std::function<int32_t(int, int, const std::string &)> MsgHandlerFunc;
virtual int registe_pserver2pserver_msg_handler(int msg_type,
MsgHandlerFunc handler) {
_msg_handler_map[msg_type] = handler;
return 0;
}

paddle::framework::Channel<std::pair<uint64_t, std::string>> _shuffled_ins;

protected:
virtual int32_t initialize() = 0;

Expand All @@ -111,7 +97,6 @@ class PSServer {
ServerParameter _config;
PSEnvironment *_environment;
std::unordered_map<uint32_t, std::shared_ptr<Table>> _table_map;
std::unordered_map<int32_t, MsgHandlerFunc> _msg_handler_map;

protected:
std::shared_ptr<framework::Scope> scope_;
Expand Down
Loading

0 comments on commit 5f6e65c

Please sign in to comment.