Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

统一ps refine #41234

Merged
merged 17 commits into from
Apr 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
353 changes: 147 additions & 206 deletions paddle/fluid/distributed/ps/service/brpc_ps_client.cc

Large diffs are not rendered by default.

185 changes: 87 additions & 98 deletions paddle/fluid/distributed/ps/service/brpc_ps_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class DownpourPsClientService : public PsService {
DownpourPsClientService() {}
virtual ~DownpourPsClientService() {}

virtual int32_t configure(PSClient *client, size_t rank_id) {
virtual int32_t Configure(PSClient *client, size_t rank_id) {
_client = client;
_rank = rank_id;
return 0;
Expand Down Expand Up @@ -139,7 +139,7 @@ class BrpcPsClient : public PSClient {
BrpcPsClient() {}
virtual ~BrpcPsClient() {
if (_running) {
flush();
Flush();
_running = false;
}
if (_async_push_dense_thread.joinable()) {
Expand All @@ -154,109 +154,98 @@ class BrpcPsClient : public PSClient {
_server_started = false;
}
}
virtual int32_t create_client2client_connection(
int pserver_timeout_ms, int pserver_connect_timeout_ms, int max_retry);
std::future<int32_t> shrink(uint32_t table_id,
virtual int32_t CreateClient2ClientConnection(int pserver_timeout_ms,
int pserver_connect_timeout_ms,
int max_retry);
std::future<int32_t> Shrink(uint32_t table_id,
const std::string threshold) override;
std::future<int32_t> load(const std::string &epoch,
std::future<int32_t> Load(const std::string &epoch,
const std::string &mode) override;
std::future<int32_t> load(uint32_t table_id, const std::string &epoch,
std::future<int32_t> Load(uint32_t table_id, const std::string &epoch,
const std::string &mode) override;

std::future<int32_t> Load(const LoadSaveContext &load_context) override;

std::future<int32_t> save(const std::string &epoch,
std::future<int32_t> Save(const std::string &epoch,
const std::string &mode) override;

std::future<int32_t> save(uint32_t table_id, const std::string &epoch,
std::future<int32_t> Save(uint32_t table_id, const std::string &epoch,
const std::string &mode) override;

virtual std::future<int32_t> Save(
const LoadSaveContext &save_context) override;

std::future<int32_t> clear() override;

std::future<int32_t> clear(uint32_t table_id) override;
std::future<int32_t> Clear() override;

std::future<int32_t> stop_server() override;
std::future<int32_t> Clear(uint32_t table_id) override;

std::future<int32_t> start_profiler() override;
std::future<int32_t> stop_profiler() override;
std::future<int32_t> StopServer() override;

void finalize_worker() override;
std::future<int32_t> StartProfiler() override;
std::future<int32_t> StopProfiler() override;

virtual std::future<int32_t> pull_dense(Region *regions, size_t region_num,
size_t table_id);
void FinalizeWorker() override;

virtual std::future<int32_t> push_dense_param(const Region *regions,
size_t region_num,
size_t table_id);
virtual std::future<int32_t> PullDense(Region *regions, size_t region_num,
size_t table_id);

virtual std::future<int32_t> push_dense(const Region *regions,
size_t region_num, size_t table_id);
void push_dense_task_consume();
virtual std::future<int32_t> pull_sparse(float **select_values,
size_t table_id,
const uint64_t *keys, size_t num,
bool is_training);
virtual std::future<int32_t> pull_sparse_param(float **select_values,
size_t table_id,
const uint64_t *keys,
size_t num, bool is_training);
virtual std::future<int32_t> PushDenseParam(const Region *regions,
size_t region_num,
size_t table_id);

virtual std::future<int32_t> Pull(RequestContext &pull_context) override;
virtual std::future<int32_t> PushDense(const Region *regions,
size_t region_num, size_t table_id);
void PushDenseTaskConsume();
virtual std::future<int32_t> PullSparse(float **select_values,
size_t table_id, const uint64_t *keys,
size_t num, bool is_training);
virtual std::future<int32_t> PullSparseParam(float **select_values,
size_t table_id,
const uint64_t *keys, size_t num,
bool is_training);

virtual std::future<int32_t> Push(RequestContext &push_context) override;
virtual std::future<int32_t> PrintTableStat(uint32_t table_id);

virtual std::future<int32_t> print_table_stat(uint32_t table_id);
virtual std::future<int32_t> Barrier(size_t table_id, uint32_t barrier_type);

virtual std::future<int32_t> barrier(size_t table_id, uint32_t barrier_type);
virtual std::future<int32_t> PullGeoParam(size_t table_id,
std::vector<float> *values,
std::vector<uint64_t> *keys,
int pserver_idx);
virtual std::future<int32_t> PushGlobalStep(int table_id,
int64_t *total_send_data,
void *done);
virtual std::future<int32_t> Flush();

virtual std::future<int32_t> pull_geo_param(size_t table_id,
std::vector<float> *values,
std::vector<uint64_t> *keys,
int pserver_idx);
virtual std::future<int32_t> push_global_step(int table_id,
int64_t *total_send_data,
void *done);
virtual std::future<int32_t> flush();

std::future<int32_t> send_client2client_msg(int msg_type, int to_client_id,
const std::string &msg) override;
std::future<int32_t> SendClient2ClientMsg(int msg_type, int to_client_id,
const std::string &msg) override;

// for local save sparse
virtual int32_t recv_and_save_table(const uint64_t table_id,
const std::string &path);
virtual int32_t RecvAndSaveTable(const uint64_t table_id,
const std::string &path);

void print_queue_size();
void print_queue_size_thread();
void PrintQueueSize();
void PrintQueueSizeThread();

protected:
virtual size_t get_server_nums() { return _server_channels.size(); }
inline brpc::Channel *get_sparse_channel(size_t server_id) {
virtual size_t GetServerNums() { return _server_channels.size(); }
inline brpc::Channel *GetSparseChannel(size_t server_id) {
return _server_channels[server_id][0].get();
}
inline brpc::Channel *get_dense_channel(size_t server_id) {
inline brpc::Channel *GetDenseChannel(size_t server_id) {
return _server_channels[server_id][1].get();
}
inline brpc::Channel *get_cmd_channel(size_t server_id) {
inline brpc::Channel *GetCmdChannel(size_t server_id) {
return _server_channels[server_id][2].get();
}
int32_t initialize() override;
int32_t Initialize() override;

private:
// virtual int32_t initialize() override;

inline uint32_t dense_dim_per_shard(uint32_t dense_dim_total,
uint32_t shard_num) {
inline uint32_t DenseDimPerShard(uint32_t dense_dim_total,
uint32_t shard_num) {
return dense_dim_total / shard_num + 1;
}

std::future<int32_t> send_cmd(uint32_t table_id, int cmd_id,
const std::vector<std::string> &param);
std::future<int32_t> SendCmd(uint32_t table_id, int cmd_id,
const std::vector<std::string> &param);

std::future<int32_t> send_save_cmd(uint32_t table_id, int cmd_id,
const std::vector<std::string> &param);
std::future<int32_t> SendSaveCmd(uint32_t table_id, int cmd_id,
const std::vector<std::string> &param);

bool _running = false;
bool _flushing = false;
Expand All @@ -276,12 +265,12 @@ class BrpcPsClient : public PSClient {

std::thread _print_thread;

int push_sparse_async_shard_merge(
int PushSparseAsyncShardMerge(
std::vector<std::shared_ptr<SparseAsyncTask>> &task_list, // NOLINT
std::vector<int> &request_kv_num, int table_id, int shard_idx, // NOLINT
ValueAccessor *accessor);

int push_sparse_async_shard_push(
int PushSparseAsyncShardPush(
std::vector<std::shared_ptr<SparseAsyncTask>> &task_list, // NOLINT
std::vector<int> &request_kv_num, int table_id, int shard_idx, // NOLINT
DownpourBrpcClosure *closure, ValueAccessor *accessor);
Expand All @@ -292,36 +281,36 @@ class BrpcPsClient : public PSClient {
_client_channels; // client2client
std::vector<std::array<std::shared_ptr<brpc::Channel>, 3>>
_server_channels; // client2server
std::future<int32_t> push_dense_raw_gradient(int table_id,
float *total_send_data,
size_t total_send_data_size,
void *done) override;

std::future<int32_t> push_sparse_raw_gradient(size_t table_id,
const uint64_t *keys,
const float **update_values,
size_t num,
void *done) override;

std::future<int32_t> push_sparse_raw_gradient_partial(
size_t table_id, const uint64_t *keys, const float **update_values,
uint32_t num, void *done, int pserver_idx) override;

std::future<int32_t> push_sparse_param(size_t table_id, const uint64_t *keys,
const float **update_values,
size_t num, void *done) override;
std::future<int32_t> push_sparse(size_t table_id, const uint64_t *keys,
const float **update_values,
size_t num) override;
void push_sparse_task_consume();
std::future<int32_t> PushDenseRawGradient(int table_id,
float *total_send_data,
size_t total_send_data_size,
void *done) override;

std::future<int32_t> PushSparseRawGradient(size_t table_id,
const uint64_t *keys,
const float **update_values,
size_t num, void *done) override;

std::future<int32_t> PushSparseRawGradientPartial(size_t table_id,
const uint64_t *keys,
const float **update_values,
uint32_t num, void *done,
int pserver_idx) override;

std::future<int32_t> PushSparseParam(size_t table_id, const uint64_t *keys,
const float **update_values, size_t num,
void *done) override;
std::future<int32_t> PushSparse(size_t table_id, const uint64_t *keys,
const float **update_values,
size_t num) override;
void PushSparseTaskConsume();

private:
int32_t start_client_service();
int32_t StartClientService();

void push_dense_raw_gradient(std::shared_ptr<DenseAsyncTask> &task, // NOLINT
float *total_send_data,
size_t total_send_data_size,
DownpourBrpcClosure *closure);
void PushDenseRawGradient(std::shared_ptr<DenseAsyncTask> &task, // NOLINT
float *total_send_data, size_t total_send_data_size,
DownpourBrpcClosure *closure);
float _mae = 0;
float _mse = 0;
uint16_t _push_times = 0;
Expand Down
Loading