Skip to content

Commit

Permalink
Ssd sparse table (#41812)
Browse files Browse the repository at this point in the history
* [cherry-pick2.3]fix compile bug of windows cuda11.5 (#41464)

cherry-pick

fix compile bug of windows cuda11.5 #41433

* fix bug of missing boost when compile cache.cc (#41449)

【chery-pick #41430】fix bug of random compile failure, due to incorrect compile order of dependencies

* Fix eager try catch (#41438) (#41477)

[Cherry-Pick]Fix eager try catch (#41438)

* Cherry-pick-PR41407, fix device_id bug for final_state op in multiprocess testcase (#41407) (#41475)

Cherry-pick PR #41407

* [BugFix] Add error hint for one_hot gpu version (#41335) (#41495)

* add one_hot gpu hint

* move allow_out_of_range judgement

* delete useless unittest

* fix bugs of reshape double grad infermeta (#41459) (#41493)

* [cherrypick-2.3] modify infer gpu memory strategy (#41427), remove cudnn_deterministic=True (#41341)  (#41491)

Co-authored-by: JingZhuangzhuang <75348594+JZZ-NOTE@users.noreply.github.com>

* [Cherry-pick][ROCm] fix dcu error in device event base, test=develop (#41523)

Cherry-pick of #41521

* [Cherry-Pick]Cherry pick PR41200, PR41474, PR41382 (#41509)

* Use `self`as a parameter of _hash_with_id function to avoid error caused by hash_id reuse (#41200)

* Add fill_constant_batch_size YAML and UT (#41474)

* Switch some dy2st UT to eager mode (#41382)

* Sitch some dy2st UT to eager mode

* Fix test_lstm and remove test_transformer

* Run test_resnet_v2 in old dy mode

* Unittest recover (#41431)

* update name

* update name

* fix test

* fix fleet bind

* update name

* update name

* fix test

* fix gpups wrapper

* remove Push/Pull/Load/Save with context in client and wrapper base class

* fix

* fix

* remove some interface

* fix

* remove

* code style

* recover

* fix

* remove code unused

* remove some unused table & accessor & CommonDenseTable => MemoryDenseTable

* fix

* fix

* fix

* recover

* remove unused code

* recover unittest

* fix

* remove

* fix

* remove code unuseful

* remove

* fix

* recover

* remove

Co-authored-by: esythan <esythan@126.com>

* add ssd sparse table

* fix

* add cache shuffle

* fix

* fix

* fix

* fix

* fix

* fix

* add unit test

* fix

Co-authored-by: Zhou Wei <1183042833@qq.com>
Co-authored-by: Sing_chan <51314274+betterpig@users.noreply.github.com>
Co-authored-by: 0x45f <23097963+0x45f@users.noreply.github.com>
Co-authored-by: pangyoki <pangyoki@126.com>
Co-authored-by: Siming Dai <908660116@qq.com>
Co-authored-by: YuanRisheng <yuanrisheng@baidu.com>
Co-authored-by: Zhang Jun <ewalker@live.cn>
Co-authored-by: JingZhuangzhuang <75348594+JZZ-NOTE@users.noreply.github.com>
Co-authored-by: Qi Li <qili93@qq.com>
Co-authored-by: esythan <esythan@126.com>
  • Loading branch information
11 people committed Apr 22, 2022
1 parent 4fd190d commit cca57c4
Show file tree
Hide file tree
Showing 37 changed files with 1,526 additions and 62 deletions.
6 changes: 2 additions & 4 deletions cmake/third_party.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -357,10 +357,8 @@ if (WITH_PSCORE)
include(external/libmct) # download, build, install libmct
list(APPEND third_party_deps extern_libmct)

if (WITH_HETERPS)
include(external/rocksdb) # download, build, install libmct
list(APPEND third_party_deps extern_rocksdb)
endif()
include(external/rocksdb) # download, build, install libmct
list(APPEND third_party_deps extern_rocksdb)
endif()

if(WITH_XBYAK)
Expand Down
70 changes: 70 additions & 0 deletions paddle/fluid/distributed/common/topk_calculator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <queue>
#include <unordered_map>

namespace paddle {
namespace distributed {
class TopkCalculator {
public:
TopkCalculator(int shard_num, size_t k)
: _shard_num(shard_num), _total_max_size(k) {
_shard_max_size = _total_max_size / shard_num;
_shard_max_size = _shard_max_size > 1 ? _shard_max_size : 1;
for (int i = 0; i < shard_num; ++i) {
_mpq.emplace(i, std::priority_queue<double, std::vector<double>,
std::greater<double>>());
}
}
~TopkCalculator() {}
bool push(int shard_id, double value) {
if (_mpq.find(shard_id) == _mpq.end()) {
return false;
}
auto &pq = _mpq[shard_id];
if (pq.size() < _shard_max_size) {
pq.push(value);
} else {
if (pq.top() < value) {
pq.pop();
pq.push(value);
}
}
return true;
}
// TODO 再进行一次堆排序merge各个shard的结果
int top() {
double total = 0;
for (const auto &item : _mpq) {
auto &pq = item.second;
if (!pq.empty()) {
total += pq.top();
}
}
return total / _shard_num;
}

private:
std::unordered_map<int, std::priority_queue<double, std::vector<double>,
std::greater<double>>>
_mpq;
int _shard_num;
size_t _total_max_size;
size_t _shard_max_size;
};

} // namespace distributed
} // namespace paddle
6 changes: 5 additions & 1 deletion paddle/fluid/distributed/ps/service/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
set(BRPC_SRCS ps_client.cc server.cc)
set_source_files_properties(${BRPC_SRCS})

set(BRPC_DEPS brpc ssl crypto protobuf gflags glog zlib leveldb snappy gflags glog device_context)
if(WITH_HETERPS)
set(BRPC_DEPS brpc ssl crypto protobuf gflags glog zlib leveldb snappy gflags glog device_context rocksdb)
else()
set(BRPC_DEPS brpc ssl crypto protobuf gflags glog zlib leveldb snappy gflags glog device_context)
endif()

brpc_library(sendrecv_rpc SRCS
${BRPC_SRCS}
Expand Down
76 changes: 76 additions & 0 deletions paddle/fluid/distributed/ps/service/brpc_ps_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,82 @@ std::future<int32_t> BrpcPsClient::Save(uint32_t table_id,
return SendSaveCmd(table_id, PS_SAVE_ONE_TABLE, {epoch, mode});
}

std::future<int32_t> BrpcPsClient::CacheShuffle(
uint32_t table_id, const std::string &path, const std::string &mode,
const std::string &cache_threshold) {
VLOG(1) << "BrpcPsClient send cmd for cache shuffle";
return SendSaveCmd(table_id, PS_CACHE_SHUFFLE, {path, mode, cache_threshold});
}

std::future<int32_t> BrpcPsClient::CacheShuffleMultiTable(
std::vector<int> tables, const std::string &path, const std::string &mode,
const std::string &cache_threshold) {
VLOG(1) << "BrpcPsClient send cmd for cache shuffle multi table one path";
std::vector<std::string> param;
param.push_back(path);
param.push_back(mode);
param.push_back(cache_threshold);
for (size_t i = 0; i < tables.size(); i++) {
param.push_back(std::to_string(tables[i]));
}
return SendSaveCmd(0, PS_CACHE_SHUFFLE, param);
}

std::future<int32_t> BrpcPsClient::SaveCache(uint32_t table_id,
const std::string &path,
const std::string &mode) {
return SendSaveCmd(table_id, PS_SAVE_ONE_CACHE_TABLE, {path, mode});
}

std::future<int32_t> BrpcPsClient::GetCacheThreshold(uint32_t table_id,
double &cache_threshold) {
int cmd_id = PS_GET_CACHE_THRESHOLD;
size_t request_call_num = _server_channels.size();
DownpourBrpcClosure *closure = new DownpourBrpcClosure(
request_call_num,
[request_call_num, cmd_id, &cache_threshold](void *done) {
int ret = 0;
auto *closure = (DownpourBrpcClosure *)done;
std::vector<double> cache_thresholds(request_call_num, 0);
for (size_t i = 0; i < request_call_num; ++i) {
if (closure->check_response(i, cmd_id) != 0) {
ret = -1;
break;
}
std::string cur_res = closure->get_response(i, cmd_id);
cache_thresholds[i] = std::stod(cur_res);
}
double sum_threshold = 0.0;
int count = 0;
for (auto t : cache_thresholds) {
if (t >= 0) {
sum_threshold += t;
++count;
}
}
if (count == 0) {
cache_threshold = 0;
} else {
cache_threshold = sum_threshold / count;
}
VLOG(1) << "client get cache threshold: " << cache_threshold;
closure->set_promise_value(ret);
});
auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise);
std::future<int> fut = promise->get_future();
for (size_t i = 0; i < request_call_num; ++i) {
closure->request(i)->set_cmd_id(cmd_id);
closure->request(i)->set_table_id(table_id);
closure->request(i)->set_client_id(_client_id);
PsService_Stub rpc_stub(GetCmdChannel(i));
closure->cntl(i)->set_timeout_ms(10800000);
rpc_stub.service(closure->cntl(i), closure->request(i),
closure->response(i), closure);
}
return fut;
}

std::future<int32_t> BrpcPsClient::Clear() {
return SendCmd(-1, PS_CLEAR_ALL_TABLE, {});
}
Expand Down
14 changes: 14 additions & 0 deletions paddle/fluid/distributed/ps/service/brpc_ps_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,20 @@ class BrpcPsClient : public PSClient {
virtual int32_t RecvAndSaveTable(const uint64_t table_id,
const std::string &path);

std::future<int32_t> CacheShuffle(
uint32_t table_id, const std::string &path, const std::string &mode,
const std::string &cache_threshold) override;

std::future<int32_t> CacheShuffleMultiTable(
std::vector<int> tables, const std::string &path, const std::string &mode,
const std::string &cache_threshold);

std::future<int32_t> SaveCache(uint32_t table_id, const std::string &path,
const std::string &mode) override;

std::future<int32_t> GetCacheThreshold(uint32_t table_id,
double &cache_threshold) override;

void PrintQueueSize();
void PrintQueueSizeThread();

Expand Down
Loading

0 comments on commit cca57c4

Please sign in to comment.