From 93dd3bd7450464af2dae81bb29e7edd5214cce11 Mon Sep 17 00:00:00 2001 From: zhaocaibei123 Date: Wed, 26 Jan 2022 09:19:33 +0000 Subject: [PATCH] geo depends --- cmake/third_party.cmake | 3 +- .../distributed/ps/service/brpc_ps_client.cc | 146 ++++++++++++++++-- .../distributed/ps/service/brpc_ps_client.h | 4 + .../ps/service/communicator/communicator.cc | 32 ++-- .../ps/service/communicator/communicator.h | 6 +- .../fluid/distributed/ps/service/ps_client.h | 11 ++ .../fluid/distributed/ps/table/CMakeLists.txt | 5 +- .../ps/table/depends/geo_recorder.h | 4 - paddle/fluid/distributed/ps/table/table.cc | 2 + paddle/fluid/distributed/test/CMakeLists.txt | 3 + .../distributed/fleet/runtime/the_one_ps.py | 3 +- 11 files changed, 185 insertions(+), 34 deletions(-) diff --git a/cmake/third_party.cmake b/cmake/third_party.cmake index ac3eff04d5383..2f07e0f7e5c48 100644 --- a/cmake/third_party.cmake +++ b/cmake/third_party.cmake @@ -211,8 +211,9 @@ include(external/dlpack) # download dlpack include(external/xxhash) # download, build, install xxhash include(external/warpctc) # download, build, install warpctc include(external/utf8proc) # download, build, install utf8proc +include(external/libmct) # download, build, install libmct -list(APPEND third_party_deps extern_eigen3 extern_gflags extern_glog extern_boost extern_xxhash) +list(APPEND third_party_deps extern_eigen3 extern_gflags extern_glog extern_boost extern_xxhash extern_libmct) list(APPEND third_party_deps extern_zlib extern_dlpack extern_warpctc extern_threadpool extern_utf8proc) include(external/lapack) # download, build, install lapack diff --git a/paddle/fluid/distributed/ps/service/brpc_ps_client.cc b/paddle/fluid/distributed/ps/service/brpc_ps_client.cc index e855fcbd02553..301136794d483 100644 --- a/paddle/fluid/distributed/ps/service/brpc_ps_client.cc +++ b/paddle/fluid/distributed/ps/service/brpc_ps_client.cc @@ -213,6 +213,7 @@ int32_t BrpcPsClient::initialize() { auto &profiler = CostProfiler::instance(); profiler.register_profiler("pserver_client_pull_dense"); profiler.register_profiler("pserver_client_pull_sparse"); + profiler.register_profiler("pserver_client_pull_sparse_param"); profiler.register_profiler("pserver_client_pull_sparse_local"); profiler.register_profiler("pserver_client_push_sparse"); profiler.register_profiler("pserver_client_push_sparse_parse"); @@ -543,6 +544,7 @@ std::future BrpcPsClient::pull_geo_param(size_t table_id, return fut; } +// for GEO std::future BrpcPsClient::push_sparse_param( size_t table_id, const uint64_t *keys, const float **update_values, size_t num, void *done) { @@ -558,18 +560,8 @@ std::future BrpcPsClient::push_sparse_param( ids.resize(request_call_num); value_ptrs.resize(request_call_num); - const auto &server_param = _config.server_param().downpour_server_param(); - uint64_t shard_num = FLAGS_pserver_sparse_table_shard_num; - for (int i = 0; i < server_param.downpour_table_param_size(); ++i) { - const auto &table_param = server_param.downpour_table_param(i); - if (table_param.table_id() == table_id) { - shard_num = table_param.shard_num(); - break; - } - } - for (size_t i = 0; i < num; ++i) { - size_t pserver_idx = get_sparse_shard(shard_num, request_call_num, keys[i]); + size_t pserver_idx = keys[i] % request_call_num; ids[pserver_idx].push_back(keys[i]); value_ptrs[pserver_idx].push_back(update_values[i]); } @@ -1003,6 +995,120 @@ std::future BrpcPsClient::pull_sparse(float **select_values, return fut; } +// for GEO +std::future BrpcPsClient::pull_sparse_param(float **select_values, + size_t table_id, + const uint64_t *keys, + size_t num, + bool is_training) { + auto timer = std::make_shared("pserver_client_pull_sparse_param"); + size_t request_call_num = _server_channels.size(); + + auto shard_sorted_kvs = std::make_shared< + std::vector>>>(); + shard_sorted_kvs->resize(request_call_num); + + for (size_t i = 0; i < num; ++i) { + size_t shard_id = keys[i] % request_call_num; + shard_sorted_kvs->at(shard_id).push_back({keys[i], select_values[i]}); + } + + auto *accessor = table_accessor(table_id); + size_t value_size = accessor->select_size(); + + DownpourBrpcClosure *closure = new DownpourBrpcClosure( + request_call_num, [shard_sorted_kvs, value_size](void *done) { + int ret = 0; + auto *closure = reinterpret_cast(done); + for (size_t i = 0; i < shard_sorted_kvs->size(); ++i) { + if (closure->check_response(i, PS_PULL_SPARSE_TABLE) != 0) { + ret = -1; + break; + } + + auto &request_kvs = shard_sorted_kvs->at(i); + auto &res_io_buffer = closure->cntl(i)->response_attachment(); + butil::IOBufBytesIterator io_buffer_itr(res_io_buffer); + uint64_t last_key = UINT64_MAX; + float *last_value_data = NULL; + + // can remove sort&unique + for (size_t kv_idx = 0; kv_idx < request_kvs.size(); ++kv_idx) { + auto *kv_pair = &(request_kvs[kv_idx]); + if (kv_pair->first == last_key) { + memcpy(reinterpret_cast(kv_pair->second), + reinterpret_cast(last_value_data), value_size); + } else { + last_key = kv_pair->first; + last_value_data = kv_pair->second; + if (value_size != + io_buffer_itr.copy_and_forward( + reinterpret_cast(last_value_data), value_size)) { + LOG(WARNING) << "res data is lack or not in format"; + ret = -1; + break; + } + } + } + } + closure->set_promise_value(ret); + }); + closure->add_timer(timer); + auto promise = std::make_shared>(); + closure->add_promise(promise); + std::future fut = promise->get_future(); + + for (size_t i = 0; i < request_call_num; ++i) { + auto &sorted_kvs = shard_sorted_kvs->at(i); + std::sort(sorted_kvs.begin(), sorted_kvs.end(), + [](const std::pair &k1, + const std::pair &k2) { + return k1.first < k2.first; + }); + + uint64_t last_key = UINT64_MAX; + uint32_t kv_request_count = 0; + size_t sorted_kv_size = sorted_kvs.size(); + auto &request_buffer = closure->cntl(i)->request_attachment(); + + request_buffer.append(reinterpret_cast(&is_training), sizeof(bool)); + std::vector keys_counter; + keys_counter.reserve(sorted_kv_size); + + for (size_t kv_idx = 0; kv_idx < sorted_kv_size; ++kv_idx) { + ++kv_request_count; + uint32_t keys = 1; + last_key = sorted_kvs[kv_idx].first; + request_buffer.append(reinterpret_cast(&last_key), + sizeof(uint64_t)); + while (kv_idx < sorted_kv_size - 1 && + last_key == sorted_kvs[kv_idx + 1].first) { + ++kv_idx; + ++keys; + } + keys_counter.push_back(keys); + } + + request_buffer.append(reinterpret_cast(keys_counter.data()), + sizeof(uint32_t) * keys_counter.size()); + + if (kv_request_count == 0) { + closure->Run(); + } else { + closure->request(i)->set_cmd_id(PS_PULL_SPARSE_TABLE); + closure->request(i)->set_table_id(table_id); + closure->request(i)->set_client_id(_client_id); + closure->request(i)->add_params((char *)&kv_request_count, // NOLINT + sizeof(uint32_t)); + PsService_Stub rpc_stub(get_cmd_channel(i)); + closure->cntl(i)->set_log_id(butil::gettimeofday_ms()); + rpc_stub.service(closure->cntl(i), closure->request(i), + closure->response(i), closure); + } + } + return fut; +} + std::future BrpcPsClient::send_client2client_msg( int msg_type, int to_client_id, const std::string &msg) { auto promise = std::make_shared>(); @@ -1067,12 +1173,14 @@ int32_t BrpcPsClient::recv_and_save_table(const uint64_t table_id, std::string var_name = ""; int64_t var_num = 0; int64_t var_shape = 0; + std::string table_class; const auto &worker_param = _config.worker_param().downpour_worker_param(); for (size_t i = 0; i < worker_param.downpour_table_param_size(); ++i) { if (worker_param.downpour_table_param(i).table_id() == table_id) { var_name = worker_param.downpour_table_param(i).common().table_name(); var_num = worker_param.downpour_table_param(i).common().table_num(); var_shape = worker_param.downpour_table_param(i).common().table_dim(); + table_class = worker_param.downpour_table_param(i).table_class(); break; } } @@ -1094,9 +1202,19 @@ int32_t BrpcPsClient::recv_and_save_table(const uint64_t table_id, save_vec.push_back(save_huge_vec.data() + i * var_shape); } - auto status = pull_sparse(reinterpret_cast(save_vec.data()), - table_id, save_key.data(), save_key.size(), true); - status.wait(); + VLOG(2) << "recv_and_save_table: table_class: " << table_class; + // TODO(zhaocaibei123): new GeoBrpcPSClient, move this to its + // recv_and_save_table + if (table_class == "MemorySparseGeoTable") { + auto status = + pull_sparse_param(reinterpret_cast(save_vec.data()), table_id, + save_key.data(), save_key.size(), true); + status.wait(); + } else { + auto status = pull_sparse(reinterpret_cast(save_vec.data()), + table_id, save_key.data(), save_key.size(), true); + status.wait(); + } // create lod tensor std::shared_ptr scope; diff --git a/paddle/fluid/distributed/ps/service/brpc_ps_client.h b/paddle/fluid/distributed/ps/service/brpc_ps_client.h index 70f406ee248dc..59ed59933db86 100644 --- a/paddle/fluid/distributed/ps/service/brpc_ps_client.h +++ b/paddle/fluid/distributed/ps/service/brpc_ps_client.h @@ -194,6 +194,10 @@ class BrpcPsClient : public PSClient { size_t table_id, const uint64_t *keys, size_t num, bool is_training); + virtual std::future pull_sparse_param(float **select_values, + size_t table_id, + const uint64_t *keys, + size_t num, bool is_training); virtual std::future print_table_stat(uint32_t table_id); diff --git a/paddle/fluid/distributed/ps/service/communicator/communicator.cc b/paddle/fluid/distributed/ps/service/communicator/communicator.cc index a73f87c1d8896..3f1667e5344d6 100644 --- a/paddle/fluid/distributed/ps/service/communicator/communicator.cc +++ b/paddle/fluid/distributed/ps/service/communicator/communicator.cc @@ -354,7 +354,7 @@ void Communicator::RpcRecvSparse(const std::string &varname, int table_id, bool training = true; - auto status = _worker_ptr->pull_sparse( + auto status = _worker_ptr->pull_sparse_param( (float **)push_g_vec.data(), table_id, // NOLINT sparse_push_keys.data(), sparse_push_keys.size(), training); status.wait(); @@ -1029,7 +1029,7 @@ void GeoCommunicator::Send(const std::vector &var_names, auto &sparse_ids_set = iter.second; auto sparse_ids_vec = std::make_shared>(); sparse_ids_vec->assign(sparse_ids_set.begin(), sparse_ids_set.end()); - sparse_id_queues_.at(key)->Push(sparse_ids_vec); + sparse_id_queues_.at(key)->Put(sparse_ids_vec); VLOG(3) << "push " << sparse_ids_vec->size() << " ids to " << key << "'s queue"; } @@ -1051,7 +1051,10 @@ void GeoCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx, for (auto &iter : send_varname_to_ctx_) { auto &ctx = iter.second; - if (!ctx.is_sparse) continue; + if (!ctx.is_sparse) { + parallel_task_nums_ += 1; + continue; + } auto &varnames = ctx.origin_varnames; PADDLE_ENFORCE_EQ( varnames.size(), 1, @@ -1060,12 +1063,11 @@ void GeoCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx, for (auto &splited_var : ctx.splited_varnames) { parallel_task_nums_ += 1; sparse_id_queues_.insert( - std::pair>>>>( + std::pair>>>( splited_var, - std::make_shared< - BlockingQueue>>>( - send_queue_size_))); + paddle::framework::MakeChannel< + std::shared_ptr>>(send_queue_size_))); } } @@ -1242,8 +1244,8 @@ std::vector GeoCommunicator::MergeSparseIds( VLOG(3) << "Merge Number of " << send_varname << " = " << merge_num; if (sparse_id_queues_.at(send_varname)->Size() > 0) { wait_times = 0; - std::shared_ptr> pop_ids = - sparse_id_queues_.at(send_varname)->Pop(); + std::shared_ptr> pop_ids = nullptr; + sparse_id_queues_.at(send_varname)->Get(pop_ids); for (size_t j = 0; j < pop_ids->size(); j++) { sparse_ids.insert(pop_ids->at(j)); } @@ -1268,6 +1270,9 @@ void GeoCommunicator::SendSparse(const std::string &varname, std::vector &sparse_ids, int table_id, int ep_idx) { platform::RecordEvent record_event("GeoCommunicator->SendSparse"); + if (sparse_ids.size() == 0) { + return; + } std::string param_name = SplitedGradToParam(varname); VLOG(1) << "In GeoCommunicator::SendSparse(" << varname << " " << param_name << ", ids.size = " << sparse_ids.size() << ", table_id: " << table_id @@ -1313,6 +1318,10 @@ void GeoCommunicator::SendSparse(const std::string &varname, t_value + j * dims1, t_old->data() + sparse_ids[j] * dims1); push_g_vec.push_back(t_value + j * dims1); + + VLOG(5) << "DEBUG GeoCommunicator::SendSparse send sparse key " + << sparse_ids[j] << " value[0] " << push_g_vec[j][0] + << " value[-1] " << push_g_vec[j][dims1 - 1]; } ++_async_call_num; @@ -1367,6 +1376,9 @@ void GeoCommunicator::RecvSparse(const std::string &varname, int table_id, cpu_ctx); for (auto j = 0; j < static_cast(keys.size()); ++j) { + VLOG(5) << "DEBUG GeoCommunicator::RecvSparse recv sparse key" << keys[j] + << "value[0] " << values[j * dims1] << " value[-1] " + << values[j * dims1 + dims1 - 1]; float *latest_data = t_latest->data() + keys[j] * dims1; float *old_data = t_old->data() + keys[j] * dims1; // pserver - old => delta diff --git a/paddle/fluid/distributed/ps/service/communicator/communicator.h b/paddle/fluid/distributed/ps/service/communicator/communicator.h index 570e668d9d5d2..c63f341607439 100644 --- a/paddle/fluid/distributed/ps/service/communicator/communicator.h +++ b/paddle/fluid/distributed/ps/service/communicator/communicator.h @@ -30,6 +30,7 @@ limitations under the License. */ #include "gflags/gflags.h" #include "paddle/fluid/distributed/ps/service/communicator/communicator_common.h" +#include "paddle/fluid/framework/channel.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/variable.h" #include "paddle/fluid/framework/variable_helper.h" @@ -626,9 +627,8 @@ class GeoCommunicator : public AsyncCommunicator { // parameter on pserver std::shared_ptr pserver_scope_; - std::unordered_map< - std::string, - std::shared_ptr>>>> + std::unordered_map>>> sparse_id_queues_; }; diff --git a/paddle/fluid/distributed/ps/service/ps_client.h b/paddle/fluid/distributed/ps/service/ps_client.h index 7db8b0c124459..21719fbdbf1d6 100644 --- a/paddle/fluid/distributed/ps/service/ps_client.h +++ b/paddle/fluid/distributed/ps/service/ps_client.h @@ -128,6 +128,17 @@ class PSClient { const uint64_t *keys, size_t num, bool is_training) = 0; + virtual std::future pull_sparse_param(float **select_values, + size_t table_id, + const uint64_t *keys, + size_t num, bool is_training) { + VLOG(0) << "Did not implement"; + std::promise promise; + std::future fut = promise.get_future(); + promise.set_value(-1); + return fut; + } + virtual ::std::future pull_sparse_ptr(char **select_values, size_t table_id, const uint64_t *keys, diff --git a/paddle/fluid/distributed/ps/table/CMakeLists.txt b/paddle/fluid/distributed/ps/table/CMakeLists.txt index b0a553f210044..9aa9ecc2afdcf 100644 --- a/paddle/fluid/distributed/ps/table/CMakeLists.txt +++ b/paddle/fluid/distributed/ps/table/CMakeLists.txt @@ -47,6 +47,9 @@ cc_library(sparse_sgd_rule SRCS sparse_sgd_rule.cc DEPS ${TABLE_DEPS} ps_framewo cc_library(ctr_accessor SRCS ctr_accessor.cc DEPS ${TABLE_DEPS} ps_framework_proto sparse_sgd_rule) cc_library(memory_sparse_table SRCS memory_sparse_table.cc DEPS ps_framework_proto ${TABLE_DEPS} fs afs_wrapper ctr_accessor common_table) -cc_library(table SRCS table.cc DEPS memory_sparse_table common_table tensor_accessor tensor_table ps_framework_proto string_helper device_context gflags glog boost) +set_source_files_properties(memory_sparse_geo_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +cc_library(memory_sparse_geo_table SRCS memory_sparse_geo_table.cc DEPS ps_framework_proto ${TABLE_DEPS} common_table) + +cc_library(table SRCS table.cc DEPS memory_sparse_table memory_sparse_geo_table common_table tensor_accessor tensor_table ps_framework_proto string_helper device_context gflags glog boost) target_link_libraries(table -fopenmp) diff --git a/paddle/fluid/distributed/ps/table/depends/geo_recorder.h b/paddle/fluid/distributed/ps/table/depends/geo_recorder.h index ad094f0dfbc48..adab0ee344bca 100644 --- a/paddle/fluid/distributed/ps/table/depends/geo_recorder.h +++ b/paddle/fluid/distributed/ps/table/depends/geo_recorder.h @@ -15,13 +15,9 @@ #pragma once #include -#include #include // NOLINT #include -#include -#include #include -#include #include namespace paddle { diff --git a/paddle/fluid/distributed/ps/table/table.cc b/paddle/fluid/distributed/ps/table/table.cc index b9b5ff12fc97a..fa8169da07ab7 100644 --- a/paddle/fluid/distributed/ps/table/table.cc +++ b/paddle/fluid/distributed/ps/table/table.cc @@ -20,6 +20,7 @@ #include "paddle/fluid/distributed/ps/table/common_dense_table.h" #include "paddle/fluid/distributed/ps/table/common_graph_table.h" #include "paddle/fluid/distributed/ps/table/common_sparse_table.h" +#include "paddle/fluid/distributed/ps/table/memory_sparse_geo_table.h" #include "paddle/fluid/distributed/ps/table/sparse_geo_table.h" #ifdef PADDLE_WITH_HETERPS #include "paddle/fluid/distributed/ps/table/ssd_sparse_table.h" @@ -43,6 +44,7 @@ REGISTER_PSCORE_CLASS(Table, TensorTable); REGISTER_PSCORE_CLASS(Table, DenseTensorTable); REGISTER_PSCORE_CLASS(Table, GlobalStepTable); REGISTER_PSCORE_CLASS(Table, MemorySparseTable); +REGISTER_PSCORE_CLASS(Table, MemorySparseGeoTable); REGISTER_PSCORE_CLASS(ValueAccessor, CommMergeAccessor); REGISTER_PSCORE_CLASS(ValueAccessor, CtrCommonAccessor); REGISTER_PSCORE_CLASS(SparseValueSGDRule, StdAdaGradSGDRule); diff --git a/paddle/fluid/distributed/test/CMakeLists.txt b/paddle/fluid/distributed/test/CMakeLists.txt index 62de82832e133..2223334ccc442 100644 --- a/paddle/fluid/distributed/test/CMakeLists.txt +++ b/paddle/fluid/distributed/test/CMakeLists.txt @@ -35,3 +35,6 @@ cc_test(ctr_accessor_test SRCS ctr_accessor_test.cc DEPS ${COMMON_DEPS} boost ta set_source_files_properties(memory_sparse_table_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) cc_test(memory_sparse_table_test SRCS memory_sparse_table_test.cc DEPS ${COMMON_DEPS} boost table) + +set_source_files_properties(memory_geo_table_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +cc_test(memory_sparse_geo_table_test SRCS memory_geo_table_test.cc DEPS ${COMMON_DEPS} boost table) diff --git a/python/paddle/distributed/fleet/runtime/the_one_ps.py b/python/paddle/distributed/fleet/runtime/the_one_ps.py index c561c250678b0..cc81f8b3e9e1c 100644 --- a/python/paddle/distributed/fleet/runtime/the_one_ps.py +++ b/python/paddle/distributed/fleet/runtime/the_one_ps.py @@ -943,7 +943,7 @@ def _get_tables(): ctx.origin_varnames()[0]] if self.compiled_strategy.is_geo_mode(): - table.table_class = "SparseGeoTable" + table.table_class = "MemorySparseGeoTable" else: all_table_proto = self.context[ "user_defined_strategy"].sparse_table_configs @@ -1306,6 +1306,7 @@ def _ps_inference_save_inference_model(self, is_dense=True, split_dense_table=self.role_maker._is_heter_parameter_server_mode, use_origin_program=True) + # TODO(zhaocaibei123): for GEO: should call GeoCommunicator::RecvDense self._communicator.pull_dense(denses) generate_vars = self.context[