diff --git a/paddle/fluid/distributed/ps/service/brpc_ps_client.cc b/paddle/fluid/distributed/ps/service/brpc_ps_client.cc index b8ccd8e744dab..f86b4b706b3e2 100644 --- a/paddle/fluid/distributed/ps/service/brpc_ps_client.cc +++ b/paddle/fluid/distributed/ps/service/brpc_ps_client.cc @@ -414,6 +414,16 @@ std::future BrpcPsClient::load(uint32_t table_id, return send_cmd(table_id, PS_LOAD_ONE_TABLE, {epoch, mode}); } +std::future BrpcPsClient::Load(const LoadSaveContext &load_context) { + if (load_context.table_id < 0) { + return send_cmd(-1, PS_LOAD_ALL_TABLE, + {load_context.epoch, load_context.mode}); + } else { + return send_cmd(load_context.table_id, PS_LOAD_ONE_TABLE, + {load_context.epoch, load_context.mode}); + } +} + std::future BrpcPsClient::save(const std::string &epoch, const std::string &mode) { VLOG(1) << "BrpcPsClient::save path " << epoch; @@ -427,6 +437,19 @@ std::future BrpcPsClient::save(uint32_t table_id, return send_save_cmd(table_id, PS_SAVE_ONE_TABLE, {epoch, mode}); } +std::future BrpcPsClient::Save(const LoadSaveContext &save_context) { + if (save_context.table_id < 0) { + VLOG(1) << "BrpcPsClient::save path " << save_context.epoch; + return send_save_cmd(-1, PS_SAVE_ALL_TABLE, + {save_context.epoch, save_context.mode}); + } else { + VLOG(1) << "BrpcPsClient::save one table path " << save_context.epoch + << " table_id " << save_context.table_id; + return send_save_cmd(save_context.table_id, PS_SAVE_ONE_TABLE, + {save_context.epoch, save_context.mode}); + } +} + std::future BrpcPsClient::clear() { return send_cmd(-1, PS_CLEAR_ALL_TABLE, {}); } @@ -505,6 +528,44 @@ std::future BrpcPsClient::barrier(size_t table_id, return send_cmd(table_id, PS_BARRIER, {std::to_string(barrier_type)}); } +std::future BrpcPsClient::Pull(RequestContext &pull_context) { + if (pull_context.value_type == Dense) { // pull dense + Region *dense_region = + reinterpret_cast(pull_context.dense_values); + pull_dense(dense_region, pull_context.num, pull_context.table); + } else { // pull sparse + uint64_t *keys = reinterpret_cast(pull_context.keys); + float **select_values = + reinterpret_cast(pull_context.sparse_values); + size_t table_id = pull_context.table; + size_t num = pull_context.num; + bool is_training = pull_context.is_training; + if (pull_context.training_mode == Geo) { // for geo + pull_sparse_param(select_values, table_id, keys, num, is_training); + } else if (pull_context.training_mode == Async) { // for async + pull_sparse(select_values, table_id, keys, num, is_training); + } + } +} + +std::future BrpcPsClient::Push(RequestContext &push_context) { + if (push_context.value_type == Dense) { // push dense + const Region *dense_region = push_context.push_context.push_dense_values; + push_dense(dense_region, push_context.num, push_context.table); + } else { // push sparse + size_t table_id = push_context.table; + size_t num = push_context.num; + bool is_training = push_context.is_training; + if (push_context.training_mode == Geo) { // for geo + // TODO(zhaocaibei) + } else if (push_context.training_mode == Async) { // for async + const uint64_t *keys = push_context.push_context.keys; + const float **update_values = push_context.push_context.push_values; + push_sparse(table_id, keys, update_values, num); + } + } +} + std::future BrpcPsClient::pull_geo_param(size_t table_id, std::vector *values, std::vector *keys, diff --git a/paddle/fluid/distributed/ps/service/brpc_ps_client.h b/paddle/fluid/distributed/ps/service/brpc_ps_client.h index 59ed59933db86..8b0cb0741b400 100644 --- a/paddle/fluid/distributed/ps/service/brpc_ps_client.h +++ b/paddle/fluid/distributed/ps/service/brpc_ps_client.h @@ -163,12 +163,17 @@ class BrpcPsClient : public PSClient { std::future load(uint32_t table_id, const std::string &epoch, const std::string &mode) override; + std::future Load(const LoadSaveContext &load_context) override; + std::future save(const std::string &epoch, const std::string &mode) override; std::future save(uint32_t table_id, const std::string &epoch, const std::string &mode) override; + virtual std::future Save( + const LoadSaveContext &save_context) override; + std::future clear() override; std::future clear(uint32_t table_id) override; @@ -199,6 +204,10 @@ class BrpcPsClient : public PSClient { const uint64_t *keys, size_t num, bool is_training); + virtual std::future Pull(RequestContext &pull_context) override; + + virtual std::future Push(RequestContext &push_context) override; + virtual std::future print_table_stat(uint32_t table_id); virtual std::future barrier(size_t table_id, uint32_t barrier_type); diff --git a/paddle/fluid/distributed/ps/service/brpc_ps_server.h b/paddle/fluid/distributed/ps/service/brpc_ps_server.h index 4310c247438ce..d81a3a5df07f1 100644 --- a/paddle/fluid/distributed/ps/service/brpc_ps_server.h +++ b/paddle/fluid/distributed/ps/service/brpc_ps_server.h @@ -51,7 +51,7 @@ class BrpcPsServer : public PSServer { _server.Join(); return 0; } - virtual int32_t port(); + int32_t port(); private: virtual int32_t initialize(); diff --git a/paddle/fluid/distributed/ps/service/graph_brpc_server.h b/paddle/fluid/distributed/ps/service/graph_brpc_server.h index aee0190850753..a978d97b296b0 100644 --- a/paddle/fluid/distributed/ps/service/graph_brpc_server.h +++ b/paddle/fluid/distributed/ps/service/graph_brpc_server.h @@ -43,7 +43,7 @@ class GraphBrpcServer : public PSServer { _server.Join(); return 0; } - virtual int32_t port(); + int32_t port(); std::condition_variable *export_cv() { return &cv_; } diff --git a/paddle/fluid/distributed/ps/service/ps_client.h b/paddle/fluid/distributed/ps/service/ps_client.h index 21719fbdbf1d6..8a2bfbe31602b 100644 --- a/paddle/fluid/distributed/ps/service/ps_client.h +++ b/paddle/fluid/distributed/ps/service/ps_client.h @@ -26,6 +26,7 @@ #include "paddle/fluid/distributed/ps/service/sendrecv.pb.h" #include "paddle/fluid/distributed/ps/table/accessor.h" #include "paddle/fluid/distributed/ps/table/graph/graph_node.h" +#include "paddle/fluid/distributed/ps/table/table.h" #include "paddle/fluid/platform/timer.h" namespace paddle { @@ -59,6 +60,41 @@ class PSClientClosure : public google::protobuf::Closure { std::vector>> _promises; }; +struct LoadSaveContext { + int table_id; + std::string epoch; + std::string mode; +}; + +enum TrainingMode { Async = 0, Sync = 1, Geo = 3 }; + +enum TrainingPhase { Init = 0, Train = 1, Save = 2 }; + +// enum ValueType { +// Sparse = 0, +// Dense = 1 +// }; + +struct PushContext { + const uint64_t *keys; + const float **push_values; + const Region *push_dense_values; +}; + +struct RequestContext { + int table; + TrainingMode training_mode; // 1 for async, 2 for geo, 3 for sync + TrainingPhase training_phase; // 1 for init, 2 for train + ValueType value_type; // 1 for sparse, 2 for dense + void *keys; + void **sparse_values; // for sparse values + Region *dense_values; // for dense values + PushContext push_context; + size_t num; + bool is_training; + void *callback; +}; + class PSClient { public: PSClient() {} @@ -86,6 +122,9 @@ class PSClient { // 指定table数据load virtual std::future load(uint32_t table_id, const std::string &epoch, const std::string &mode) = 0; + // context配置load选项 + virtual std::future Load(const LoadSaveContext &load_context) = 0; + // 全量table数据save value_accessor根据mode,可能有不同的save条件 virtual std::future save(const std::string &epoch, const std::string &mode) = 0; @@ -93,6 +132,8 @@ class PSClient { virtual std::future save(uint32_t table_id, const std::string &epoch, const std::string &mode) = 0; + virtual std::future Save(const LoadSaveContext &save_context) = 0; + // 清空table数据 virtual std::future clear() = 0; virtual std::future clear(uint32_t table_id) = 0; @@ -107,6 +148,8 @@ class PSClient { virtual std::future pull_dense(Region *regions, size_t region_num, size_t table_id) = 0; // 保留 + virtual std::future Push(RequestContext &push_context) = 0; + // firstly push dense param for parameter server // this is neccessary because dense weight initialized in trainer on cold // start @@ -117,6 +160,9 @@ class PSClient { virtual std::future push_dense(const Region *regions, size_t region_num, size_t table_id) = 0; + + virtual std::future Pull(RequestContext &pull_context) = 0; + // 使用keys进行pull请求,结果填充values // keys和values的个数均为num个,每个value占用select_size空间 // future结束前keys和values缓冲区不能再次使用 diff --git a/paddle/fluid/distributed/ps/service/ps_local_client.cc b/paddle/fluid/distributed/ps/service/ps_local_client.cc index 972cce135f189..9e364b6d3ed7a 100644 --- a/paddle/fluid/distributed/ps/service/ps_local_client.cc +++ b/paddle/fluid/distributed/ps/service/ps_local_client.cc @@ -56,6 +56,19 @@ ::std::future PsLocalClient::load(uint32_t table_id, return done(); } +std::future PsLocalClient::Load(const LoadSaveContext& load_context) { + if (load_context.table_id < 0) { + for (auto& it : _table_map) { + load(it.first, load_context.epoch, load_context.mode); + } + return done(); + } else { + auto* table_ptr = table(load_context.table_id); + table_ptr->load(load_context.epoch, load_context.mode); + return done(); + } +} + ::std::future PsLocalClient::save(const std::string& epoch, const std::string& mode) { // TODO @@ -74,6 +87,21 @@ ::std::future PsLocalClient::save(uint32_t table_id, return done(); } +::std::future PsLocalClient::Save( + const LoadSaveContext& save_context) { + if (save_context.table_id < 0) { + for (auto& it : _table_map) { + save(it.first, save_context.epoch, save_context.mode); + } + return done(); + } else { + auto* table_ptr = table(save_context.table_id); + table_ptr->flush(); + table_ptr->save(save_context.epoch, save_context.mode); + return done(); + } +} + ::std::future PsLocalClient::clear() { // TODO return done(); @@ -93,6 +121,51 @@ ::std::future PsLocalClient::stop_server() { return done(); } +::std::future PsLocalClient::Pull(RequestContext& pull_context) { + if (pull_context.value_type == Dense) { // pull dense + Region* dense_region = reinterpret_cast(pull_context.dense_values); + pull_dense(dense_region, pull_context.num, pull_context.table); + } else { // pull sparse + uint64_t* keys = reinterpret_cast(pull_context.keys); + char** select_values = reinterpret_cast(pull_context.sparse_values); + size_t table_id = pull_context.table; + size_t num = pull_context.num; + pull_sparse_ptr(select_values, table_id, keys, num); + } +} + +::std::future PsLocalClient::Push(RequestContext& push_context) { + if (push_context.value_type == Dense) { // push dense + if (push_context.training_phase == Init) { + const Region* regions = push_context.push_context.push_dense_values; + size_t region_num = push_context.num; + push_dense_param(regions, region_num, push_context.table); + } else { + if (push_context.training_mode == Geo) { // geo + float* total_send_data = + reinterpret_cast(push_context.dense_values); + size_t total_send_data_size = push_context.num; + push_dense_raw_gradient(push_context.table, total_send_data, + total_send_data_size, push_context.callback); + } else { // async and sync + const Region* regions = push_context.push_context.push_dense_values; + size_t region_num = push_context.num; + push_dense(regions, region_num, push_context.table); + } + } + } else { // push sparse + if (push_context.training_mode == Async) { + const uint64_t* keys = push_context.push_context.keys; + const float** update_values = push_context.push_context.push_values; + size_t table_id = push_context.table; + size_t num = push_context.num; + push_sparse(table_id, keys, update_values, num); + } else { + // TODO + } + } +} + ::std::future PsLocalClient::pull_dense(Region* regions, size_t region_num, size_t table_id) { diff --git a/paddle/fluid/distributed/ps/service/ps_local_client.h b/paddle/fluid/distributed/ps/service/ps_local_client.h index e73974ac56286..83ca558e3d2cb 100644 --- a/paddle/fluid/distributed/ps/service/ps_local_client.h +++ b/paddle/fluid/distributed/ps/service/ps_local_client.h @@ -39,12 +39,16 @@ class PsLocalClient : public PSClient { virtual ::std::future load(uint32_t table_id, const std::string& epoch, const std::string& mode) override; + virtual std::future Load( + const LoadSaveContext& load_context) override; virtual ::std::future save(const std::string& epoch, const std::string& mode) override; virtual ::std::future save(uint32_t table_id, const std::string& epoch, const std::string& mode) override; + virtual std::future Save( + const LoadSaveContext& save_context) override; virtual ::std::future clear() override; virtual ::std::future clear(uint32_t table_id) override; @@ -55,6 +59,10 @@ class PsLocalClient : public PSClient { virtual ::std::future pull_dense(Region* regions, size_t region_num, size_t table_id); + virtual ::std::future Pull(RequestContext& pull_context) override; + + virtual ::std::future Push(RequestContext& push_context) override; + virtual ::std::future push_dense(const Region* regions, size_t region_num, size_t table_id); diff --git a/paddle/fluid/distributed/ps/service/ps_local_server.h b/paddle/fluid/distributed/ps/service/ps_local_server.h index 91f8bc4c91271..31b52126fc576 100644 --- a/paddle/fluid/distributed/ps/service/ps_local_server.h +++ b/paddle/fluid/distributed/ps/service/ps_local_server.h @@ -28,7 +28,6 @@ class PsLocalServer : public PSServer { virtual uint64_t start() { return 0; } virtual uint64_t start(const std::string &ip, uint32_t port) { return 0; } virtual int32_t stop() { return 0; } - virtual int32_t port() { return 0; } virtual int32_t configure( const PSParameter &config, PSEnvironment &env, size_t server_rank, const std::vector &server_sub_program = {}) { diff --git a/paddle/fluid/distributed/ps/service/server.cc b/paddle/fluid/distributed/ps/service/server.cc index 5f1974e3e610c..893f671359e40 100644 --- a/paddle/fluid/distributed/ps/service/server.cc +++ b/paddle/fluid/distributed/ps/service/server.cc @@ -67,8 +67,6 @@ int32_t PSServer::configure( _config = config.server_param(); _rank = server_rank; _environment = &env; - _shuffled_ins = - paddle::framework::MakeChannel>(); size_t shard_num = env.get_ps_servers().size(); const auto &downpour_param = _config.downpour_server_param(); diff --git a/paddle/fluid/distributed/ps/service/server.h b/paddle/fluid/distributed/ps/service/server.h index 160d4a6128295..d2804405b4198 100644 --- a/paddle/fluid/distributed/ps/service/server.h +++ b/paddle/fluid/distributed/ps/service/server.h @@ -69,11 +69,6 @@ class PSServer { const PSParameter &config, PSEnvironment &env, size_t server_rank, const std::vector &server_sub_program = {}); - // return server_ip - virtual std::string ip() { return butil::my_ip_cstr(); } - // return server_port - virtual int32_t port() = 0; - virtual uint64_t start(const std::string &ip, uint32_t port) = 0; virtual int32_t stop() = 0; @@ -94,15 +89,6 @@ class PSServer { return &_table_map; } - typedef std::function MsgHandlerFunc; - virtual int registe_pserver2pserver_msg_handler(int msg_type, - MsgHandlerFunc handler) { - _msg_handler_map[msg_type] = handler; - return 0; - } - - paddle::framework::Channel> _shuffled_ins; - protected: virtual int32_t initialize() = 0; @@ -111,7 +97,6 @@ class PSServer { ServerParameter _config; PSEnvironment *_environment; std::unordered_map> _table_map; - std::unordered_map _msg_handler_map; protected: std::shared_ptr scope_; diff --git a/paddle/fluid/distributed/ps/table/accessor.h b/paddle/fluid/distributed/ps/table/accessor.h index 7c91a60864980..07c211bb9c128 100644 --- a/paddle/fluid/distributed/ps/table/accessor.h +++ b/paddle/fluid/distributed/ps/table/accessor.h @@ -45,6 +45,17 @@ struct DataConverter { std::string deconverter; }; +struct AccessorInfo { + size_t dim; + size_t size; + size_t select_size; + size_t select_dim; + size_t update_size; + size_t update_dim; + size_t mf_size; + size_t fea_dim; +}; + class ValueAccessor { public: ValueAccessor() {} @@ -68,6 +79,8 @@ class ValueAccessor { } virtual int initialize() = 0; + virtual void GetTableInfo(AccessorInfo& info) = 0; + // value维度 virtual size_t dim() = 0; // value各个维度的size @@ -163,6 +176,7 @@ class ValueAccessor { TableAccessorParameter _config; std::unordered_map> _data_coverter_map; + AccessorInfo _accessor_info; }; REGISTER_PSCORE_REGISTERER(ValueAccessor); } // namespace distributed diff --git a/paddle/fluid/distributed/ps/table/common_dense_table.cc b/paddle/fluid/distributed/ps/table/common_dense_table.cc index 607469e2f7b0d..cc0f5867a3d65 100644 --- a/paddle/fluid/distributed/ps/table/common_dense_table.cc +++ b/paddle/fluid/distributed/ps/table/common_dense_table.cc @@ -128,6 +128,21 @@ int32_t CommonDenseTable::set_global_lr(float* lr) { return 0; } +int32_t CommonDenseTable::Pull(TableContext& context) { + CHECK(context.value_type == Dense); + float* pull_values = context.pull_context.values; + return pull_dense(pull_values, context.num); +} + +int32_t CommonDenseTable::Push(TableContext& context) { + CHECK(context.value_type == Dense); + if (context.pull_context.values != nullptr) { + const float* values = context.push_context.values; + return push_dense(values, context.num); + } + return 0; +} + int32_t CommonDenseTable::pull_dense(float* pull_values, size_t num) { std::copy(values_[param_idx_].begin(), values_[param_idx_].end(), pull_values); diff --git a/paddle/fluid/distributed/ps/table/common_dense_table.h b/paddle/fluid/distributed/ps/table/common_dense_table.h index a4c0f29ddb877..cad49a0a449c4 100644 --- a/paddle/fluid/distributed/ps/table/common_dense_table.h +++ b/paddle/fluid/distributed/ps/table/common_dense_table.h @@ -40,6 +40,8 @@ class CommonDenseTable : public DenseTable { const std::string& name); virtual int32_t initialize_value(); virtual int32_t initialize_optimizer(); + virtual int32_t Pull(TableContext& context); + virtual int32_t Push(TableContext& context); int32_t pull_dense(float* pull_values, size_t num) override; int32_t push_dense_param(const float* values, size_t num) override; int32_t push_dense(const float* values, size_t num) override; diff --git a/paddle/fluid/distributed/ps/table/common_graph_table.h b/paddle/fluid/distributed/ps/table/common_graph_table.h index 7946569525cc4..f6f127621b947 100644 --- a/paddle/fluid/distributed/ps/table/common_graph_table.h +++ b/paddle/fluid/distributed/ps/table/common_graph_table.h @@ -454,6 +454,9 @@ class GraphTable : public SparseTable { int32_t get_server_index_by_id(int64_t id); Node *find_node(int64_t id); + virtual int32_t Pull(TableContext &context) { return 0; } + virtual int32_t Push(TableContext &context) { return 0; } + virtual int32_t pull_sparse(float *values, const PullSparseValue &pull_value) { return 0; diff --git a/paddle/fluid/distributed/ps/table/common_sparse_table.cc b/paddle/fluid/distributed/ps/table/common_sparse_table.cc index b44d08b937a96..45be53335e1a1 100644 --- a/paddle/fluid/distributed/ps/table/common_sparse_table.cc +++ b/paddle/fluid/distributed/ps/table/common_sparse_table.cc @@ -355,6 +355,32 @@ int32_t CommonSparseTable::pour() { return 0; } +int32_t CommonSparseTable::Pull(TableContext& context) { + CHECK(context.value_type == Sparse); + if (context.use_ptr) { + char** pull_values = context.pull_context.ptr_values; + const uint64_t* keys = context.pull_context.keys; + return pull_sparse_ptr(pull_values, keys, context.num); + } else { + float* pull_values = context.pull_context.values; + const PullSparseValue& pull_value = context.pull_context.pull_value; + return pull_sparse(pull_values, pull_value); + } +} + +int32_t CommonSparseTable::Push(TableContext& context) { + CHECK(context.value_type == Sparse); + if (context.pull_context.values != nullptr) { + const float* values = context.push_context.values; + const uint64_t* keys = context.push_context.keys; + return push_sparse(keys, values, context.num); + } else { + const float** values = context.push_context.ptr_values; + const uint64_t* keys = context.push_context.keys; + return push_sparse(keys, values, context.num); + } +} + int32_t CommonSparseTable::pull_sparse(float* pull_values, const PullSparseValue& pull_value) { auto shard_num = task_pool_size_; diff --git a/paddle/fluid/distributed/ps/table/common_sparse_table.h b/paddle/fluid/distributed/ps/table/common_sparse_table.h index 82481dcd584e4..138c544742066 100644 --- a/paddle/fluid/distributed/ps/table/common_sparse_table.h +++ b/paddle/fluid/distributed/ps/table/common_sparse_table.h @@ -121,6 +121,9 @@ class CommonSparseTable : public SparseTable { virtual int32_t push_dense(const float* values, size_t num) { return 0; } // unused method end + virtual int32_t Pull(TableContext& context); + virtual int32_t Push(TableContext& context); + virtual int32_t initialize(); virtual int32_t initialize_shard() { return 0; } virtual int32_t initialize_value(); diff --git a/paddle/fluid/distributed/ps/table/common_table.h b/paddle/fluid/distributed/ps/table/common_table.h index bac826dfe0e20..3d291c0152246 100644 --- a/paddle/fluid/distributed/ps/table/common_table.h +++ b/paddle/fluid/distributed/ps/table/common_table.h @@ -119,6 +119,9 @@ class BarrierTable : public Table { virtual void *get_shard(size_t shard_idx) { return 0; } + virtual int32_t Pull(TableContext &context) { return 0; } + virtual int32_t Push(TableContext &context) { return 0; } + int32_t pull_dense(float *values, size_t num) override { return 0; } int32_t push_dense(const float *values, size_t num) override { return 0; } diff --git a/paddle/fluid/distributed/ps/table/ctr_accessor.cc b/paddle/fluid/distributed/ps/table/ctr_accessor.cc index 866bd8114ccea..43e143dca901b 100644 --- a/paddle/fluid/distributed/ps/table/ctr_accessor.cc +++ b/paddle/fluid/distributed/ps/table/ctr_accessor.cc @@ -38,6 +38,16 @@ int CtrCommonAccessor::initialize() { return 0; } +void CtrCommonAccessor::GetTableInfo(AccessorInfo& info) { + info.dim = dim(); + info.size = size(); + info.select_dim = select_dim(); + info.select_size = select_size(); + info.update_dim = update_dim(); + info.update_size = update_size(); + info.fea_dim = fea_dim(); +} + size_t CtrCommonAccessor::dim() { return common_feature_value.dim(); } size_t CtrCommonAccessor::dim_size(size_t dim) { diff --git a/paddle/fluid/distributed/ps/table/ctr_accessor.h b/paddle/fluid/distributed/ps/table/ctr_accessor.h index 1e31fec04649b..bc46217955a8a 100644 --- a/paddle/fluid/distributed/ps/table/ctr_accessor.h +++ b/paddle/fluid/distributed/ps/table/ctr_accessor.h @@ -126,6 +126,7 @@ class CtrCommonAccessor : public ValueAccessor { virtual int initialize(); virtual ~CtrCommonAccessor() {} + virtual void GetTableInfo(AccessorInfo& info); // value维度 virtual size_t dim(); // value各个维度的size diff --git a/paddle/fluid/distributed/ps/table/ctr_double_accessor.cc b/paddle/fluid/distributed/ps/table/ctr_double_accessor.cc index b07bcf70ad7af..bccf1fdebafa0 100644 --- a/paddle/fluid/distributed/ps/table/ctr_double_accessor.cc +++ b/paddle/fluid/distributed/ps/table/ctr_double_accessor.cc @@ -37,6 +37,16 @@ int DownpourCtrDoubleAccessor::initialize() { return 0; } +void DownpourCtrDoubleAccessor::GetTableInfo(AccessorInfo& info) { + info.dim = dim(); + info.size = size(); + info.select_dim = select_dim(); + info.select_size = select_size(); + info.update_dim = update_dim(); + info.update_size = update_size(); + info.fea_dim = fea_dim(); +} + size_t DownpourCtrDoubleAccessor::dim() { auto embedx_dim = _config.embedx_dim(); return DownpourCtrDoubleFeatureValue::dim(embedx_dim); diff --git a/paddle/fluid/distributed/ps/table/ctr_double_accessor.h b/paddle/fluid/distributed/ps/table/ctr_double_accessor.h index d7c717ace0988..d7942634e8600 100644 --- a/paddle/fluid/distributed/ps/table/ctr_double_accessor.h +++ b/paddle/fluid/distributed/ps/table/ctr_double_accessor.h @@ -168,6 +168,7 @@ class DownpourCtrDoubleAccessor : public ValueAccessor { DownpourCtrDoubleAccessor() {} virtual ~DownpourCtrDoubleAccessor() {} virtual int initialize(); + virtual void GetTableInfo(AccessorInfo& info); // value维度 virtual size_t dim(); // value各个维度的size diff --git a/paddle/fluid/distributed/ps/table/depends/sparse_utils.h b/paddle/fluid/distributed/ps/table/depends/sparse_utils.h index 708f7786bf3b0..98e0250acc4d6 100644 --- a/paddle/fluid/distributed/ps/table/depends/sparse_utils.h +++ b/paddle/fluid/distributed/ps/table/depends/sparse_utils.h @@ -58,7 +58,7 @@ struct PullSparseValue { std::vector* offset_shard) const { offset_shard->reserve(numel_ / shard_num + 1); for (int x = 0; x < numel_; ++x) { - if (feasigns_[x] % shard_num == shard_id) { + if (int(feasigns_[x] % shard_num) == shard_id) { offset_shard->push_back(x); } } diff --git a/paddle/fluid/distributed/ps/table/downpour_ctr_accessor.cc b/paddle/fluid/distributed/ps/table/downpour_ctr_accessor.cc index 5f22c3a436f1f..e8ca7430351de 100644 --- a/paddle/fluid/distributed/ps/table/downpour_ctr_accessor.cc +++ b/paddle/fluid/distributed/ps/table/downpour_ctr_accessor.cc @@ -37,6 +37,16 @@ int DownpourCtrAccessor::initialize() { return 0; } +void DownpourCtrAccessor::GetTableInfo(AccessorInfo& info) { + info.dim = dim(); + info.size = size(); + info.select_dim = select_dim(); + info.select_size = select_size(); + info.update_dim = update_dim(); + info.update_size = update_size(); + info.fea_dim = fea_dim(); +} + size_t DownpourCtrAccessor::dim() { auto embedx_dim = _config.embedx_dim(); return DownpourCtrFeatureValue::dim(embedx_dim); diff --git a/paddle/fluid/distributed/ps/table/downpour_ctr_accessor.h b/paddle/fluid/distributed/ps/table/downpour_ctr_accessor.h index 5de7b12e01f0d..11991ad044ff6 100644 --- a/paddle/fluid/distributed/ps/table/downpour_ctr_accessor.h +++ b/paddle/fluid/distributed/ps/table/downpour_ctr_accessor.h @@ -160,6 +160,7 @@ class DownpourCtrAccessor : public ValueAccessor { virtual ~DownpourCtrAccessor() {} virtual int initialize(); + virtual void GetTableInfo(AccessorInfo& info); // value维度 virtual size_t dim(); // value各个维度的size diff --git a/paddle/fluid/distributed/ps/table/memory_sparse_geo_table.h b/paddle/fluid/distributed/ps/table/memory_sparse_geo_table.h index 89c4fc15ae279..3b43f99543fdd 100644 --- a/paddle/fluid/distributed/ps/table/memory_sparse_geo_table.h +++ b/paddle/fluid/distributed/ps/table/memory_sparse_geo_table.h @@ -48,6 +48,8 @@ class MemorySparseGeoTable : public SparseTable { virtual int32_t save(const std::string& path, const std::string& param) { return 0; } + virtual int32_t Pull(TableContext& context) { return 0; } + virtual int32_t Push(TableContext& context) { return 0; } virtual int32_t flush() { return 0; } virtual int32_t shrink(const std::string& param) { return 0; } virtual void clear() { return; } diff --git a/paddle/fluid/distributed/ps/table/memory_sparse_table.cc b/paddle/fluid/distributed/ps/table/memory_sparse_table.cc index 7ce6e9005cf56..98454ca747d31 100644 --- a/paddle/fluid/distributed/ps/table/memory_sparse_table.cc +++ b/paddle/fluid/distributed/ps/table/memory_sparse_table.cc @@ -390,6 +390,26 @@ std::pair MemorySparseTable::print_table_stat() { return {feasign_size, mf_size}; } +int32_t MemorySparseTable::Pull(TableContext& context) { + CHECK(context.value_type == Sparse); + if (context.use_ptr) { + char** pull_values = context.pull_context.ptr_values; + const uint64_t* keys = context.pull_context.keys; + return pull_sparse_ptr(pull_values, keys, context.num); + } else { + float* pull_values = context.pull_context.values; + const PullSparseValue& pull_value = context.pull_context.pull_value; + return pull_sparse(pull_values, pull_value); + } +} + +int32_t MemorySparseTable::Push(TableContext& context) { + CHECK(context.value_type == Sparse); + + const uint64_t* keys = context.push_context.keys; + return push_sparse(keys, context.push_context.ptr_values, context.num); +} + int32_t MemorySparseTable::pull_sparse(float* pull_values, const PullSparseValue& pull_value) { CostTimer timer("pserver_sparse_select_all"); diff --git a/paddle/fluid/distributed/ps/table/memory_sparse_table.h b/paddle/fluid/distributed/ps/table/memory_sparse_table.h index 5770f25f8f41d..d26c67319760d 100644 --- a/paddle/fluid/distributed/ps/table/memory_sparse_table.h +++ b/paddle/fluid/distributed/ps/table/memory_sparse_table.h @@ -48,6 +48,9 @@ class MemorySparseTable : public SparseTable { virtual int32_t push_dense(const float* values, size_t num) { return 0; } // unused method end + virtual int32_t Pull(TableContext& context); + virtual int32_t Push(TableContext& context); + virtual int32_t initialize(); virtual int32_t initialize_shard() { return 0; } virtual int32_t initialize_value(); diff --git a/paddle/fluid/distributed/ps/table/ssd_sparse_table.cc b/paddle/fluid/distributed/ps/table/ssd_sparse_table.cc index 60514b4e19ffa..5bc58bc5a1108 100644 --- a/paddle/fluid/distributed/ps/table/ssd_sparse_table.cc +++ b/paddle/fluid/distributed/ps/table/ssd_sparse_table.cc @@ -61,6 +61,21 @@ int32_t SSDSparseTable::initialize() { return 0; } +int32_t SSDSparseTable::Pull(TableContext& context) { + CHECK(context.value_type == Sparse); + if (context.use_ptr) { + char** pull_values = context.pull_context.ptr_values; + const uint64_t* keys = context.pull_context.keys; + return pull_sparse_ptr(pull_values, keys, context.num); + } else { + float* pull_values = context.pull_context.values; + const PullSparseValue& pull_value = context.pull_context.pull_value; + return pull_sparse(pull_values, pull_value); + } +} + +int32_t SSDSparseTable::Push(TableContext& context) { return 0; } + int32_t SSDSparseTable::pull_sparse(float* pull_values, const PullSparseValue& pull_value) { auto shard_num = task_pool_size_; diff --git a/paddle/fluid/distributed/ps/table/ssd_sparse_table.h b/paddle/fluid/distributed/ps/table/ssd_sparse_table.h index f5e8a7067e0e0..3a703d7d966d3 100644 --- a/paddle/fluid/distributed/ps/table/ssd_sparse_table.h +++ b/paddle/fluid/distributed/ps/table/ssd_sparse_table.h @@ -42,6 +42,9 @@ class SSDSparseTable : public CommonSparseTable { // exchange data virtual int32_t update_table(); + virtual int32_t Pull(TableContext& context); + virtual int32_t Push(TableContext& context); + virtual int32_t pull_sparse(float* values, const PullSparseValue& pull_value); virtual int32_t pull_sparse_ptr(char** pull_values, const uint64_t* keys, diff --git a/paddle/fluid/distributed/ps/table/table.h b/paddle/fluid/distributed/ps/table/table.h index da1bb668ccfa3..2bd2a42b6c58f 100644 --- a/paddle/fluid/distributed/ps/table/table.h +++ b/paddle/fluid/distributed/ps/table/table.h @@ -32,6 +32,30 @@ namespace paddle { namespace distributed { + +enum ValueType { Sparse = 0, Dense = 1 }; + +struct PullContext { + const uint64_t *keys; + const PullSparseValue pull_value; + float *values; + char **ptr_values; +}; + +struct TablePushContext { + const uint64_t *keys; + const float *values; + const float **ptr_values; +}; + +struct TableContext { + ValueType value_type; + PullContext pull_context; + TablePushContext push_context; + size_t num; + bool use_ptr; +}; + class Table { public: Table() {} @@ -39,6 +63,8 @@ class Table { virtual int32_t initialize(const TableParameter &config, const FsClientParameter &fs_config); + virtual int32_t Pull(TableContext &context) = 0; + virtual int32_t Push(TableContext &context) = 0; virtual int32_t pull_dense(float *values, size_t num) = 0; virtual int32_t push_dense(const float *values, size_t num) = 0; // for push global_step diff --git a/paddle/fluid/distributed/ps/table/tensor_accessor.cc b/paddle/fluid/distributed/ps/table/tensor_accessor.cc index 70a580c1e53a9..8c5349bff832c 100644 --- a/paddle/fluid/distributed/ps/table/tensor_accessor.cc +++ b/paddle/fluid/distributed/ps/table/tensor_accessor.cc @@ -20,6 +20,16 @@ namespace distributed { int CommMergeAccessor::initialize() { return 0; } +void CommMergeAccessor::GetTableInfo(AccessorInfo &info) { + info.dim = dim(); + info.size = size(); + info.select_dim = select_dim(); + info.select_size = select_size(); + info.update_dim = update_dim(); + info.update_size = update_size(); + info.fea_dim = fea_dim(); +} + // value 维度 size_t CommMergeAccessor::dim() { return 0; } diff --git a/paddle/fluid/distributed/ps/table/tensor_accessor.h b/paddle/fluid/distributed/ps/table/tensor_accessor.h index 5041b8fdf8733..1873b743b44ec 100644 --- a/paddle/fluid/distributed/ps/table/tensor_accessor.h +++ b/paddle/fluid/distributed/ps/table/tensor_accessor.h @@ -30,6 +30,7 @@ class CommMergeAccessor : public ValueAccessor { CommMergeAccessor() {} virtual ~CommMergeAccessor() {} virtual int initialize(); + virtual void GetTableInfo(AccessorInfo &info); // value维度 virtual size_t dim(); // value各个维度的size diff --git a/paddle/fluid/distributed/ps/table/tensor_table.h b/paddle/fluid/distributed/ps/table/tensor_table.h index 64d81327acc55..23a62365c0f5a 100644 --- a/paddle/fluid/distributed/ps/table/tensor_table.h +++ b/paddle/fluid/distributed/ps/table/tensor_table.h @@ -48,6 +48,8 @@ class TensorTable : public Table { TensorTable() {} virtual ~TensorTable() {} + virtual int32_t Pull(TableContext &context) { return 0; } + virtual int32_t Push(TableContext &context) { return 0; } int32_t pull_dense(float *values, size_t num) override { return 0; } int32_t push_dense(const float *values, size_t num) override { return 0; } diff --git a/paddle/fluid/distributed/ps/wrapper/fleet.cc b/paddle/fluid/distributed/ps/wrapper/fleet.cc index 0588dbdf0fc61..c887cfeb71eef 100644 --- a/paddle/fluid/distributed/ps/wrapper/fleet.cc +++ b/paddle/fluid/distributed/ps/wrapper/fleet.cc @@ -30,6 +30,32 @@ bool FleetWrapper::is_initialized_ = false; std::shared_ptr FleetWrapper::pserver_ptr_ = NULL; +void FleetWrapper::Stop() { StopServer(); } + +void FleetWrapper::Load(WrapperContext& context) { + auto table_id = context.table_id; + if (table_id >= 0 && context.meta != "") { + LoadSparseOnServer(context.path, context.meta, context.table_id); + return; + } + if (table_id < 0) { // laod all + LoadModel(context.path, context.mode); + } else { // load one table + LoadModelOneTable(table_id, context.path, context.mode); + } + return; +} + +void FleetWrapper::Save(WrapperContext& context) { + auto table_id = context.table_id; + if (table_id < 0) { + SaveModel(context.path, context.mode); + } else { + SaveModelOneTable(table_id, context.path, context.mode); + } + return; +} + void FleetWrapper::SetClient2ClientConfig(int request_timeout_ms, int connect_timeout_ms, int max_retry) { diff --git a/paddle/fluid/distributed/ps/wrapper/fleet.h b/paddle/fluid/distributed/ps/wrapper/fleet.h index a535b8c5bf8f9..d68c453c6d51b 100644 --- a/paddle/fluid/distributed/ps/wrapper/fleet.h +++ b/paddle/fluid/distributed/ps/wrapper/fleet.h @@ -25,6 +25,7 @@ limitations under the License. */ #include "paddle/fluid/distributed/ps/service/communicator/communicator_common.h" #include "paddle/fluid/distributed/ps/service/ps_service/service.h" +#include "paddle/fluid/distributed/ps/wrapper/ps_wrapper.h" #include "paddle/fluid/framework/archive.h" #include "paddle/fluid/framework/io/fs.h" #include "paddle/fluid/framework/io/shell.h" @@ -54,7 +55,7 @@ using framework::Variable; using RpcCtxMap = std::unordered_map; -class FleetWrapper { +class FleetWrapper : public PSWrapper { public: virtual ~FleetWrapper() {} FleetWrapper() { @@ -68,7 +69,13 @@ class FleetWrapper { // pserver request max retry client2client_max_retry_ = 3; } + virtual int32_t Initialize(InitContext& context) { return 0; } + virtual void Stop() override; + + virtual void Load(WrapperContext& context) override; + + virtual void Save(WrapperContext& context) override; // set client to client communication config void SetClient2ClientConfig(int request_timeout_ms, int connect_timeout_ms, int max_retry); diff --git a/paddle/fluid/distributed/ps/wrapper/ps_wrapper.h b/paddle/fluid/distributed/ps/wrapper/ps_wrapper.h index c92835aa995ad..ca02ad31195ef 100755 --- a/paddle/fluid/distributed/ps/wrapper/ps_wrapper.h +++ b/paddle/fluid/distributed/ps/wrapper/ps_wrapper.h @@ -1,18 +1,84 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef PADDLE_FLUID_DISTRIBUTED_PS_WRAPPER_PS_WRAPPER_H_ -#define PADDLE_FLUID_DISTRIBUTED_PS_WRAPPER_PS_WRAPPER_H_ - -#endif // PADDLE_FLUID_DISTRIBUTED_PS_WRAPPER_PS_WRAPPER_H_ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "paddle/fluid/distributed/ps/service/communicator/communicator_common.h" +#include "paddle/fluid/distributed/ps/service/ps_service/service.h" +#include "paddle/fluid/framework/archive.h" +#include "paddle/fluid/framework/io/fs.h" +#include "paddle/fluid/framework/io/shell.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/framework/variable_helper.h" +#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN + +namespace paddle { +namespace framework { +class Scope; +class SelectedRows; +class Variable; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace distributed { + +class PSCore; + +using framework::LoDTensor; +using framework::Scope; +using phi::SelectedRows; +using framework::Variable; + +using RpcCtxMap = std::unordered_map; + +struct WrapperContext { + uint32_t table_id; + const std::string path; + const int mode; + const std::string meta; +}; + +struct InitContext { + const std::vector dev_ids; // for gpu +}; + +class PSWrapper { + public: + virtual ~PSWrapper() {} + PSWrapper() {} + // init server + + virtual int32_t Initialize(InitContext& context) = 0; + + virtual void Stop() = 0; + + virtual void Load(WrapperContext& context) = 0; + + virtual void Save(WrapperContext& context) = 0; +}; + +} // end namespace distributed +} // end namespace paddle