Commit

Merge pull request PaddlePaddle#663 from graphcore/develop_to_develop-ipu_bfd54663b9bb97a3653dee75fded84e3212802ba

[AUTO PR] Pulling [develop] into [develop-ipu] [2022-04-22 03:05:05+08:00]
gglin001 authored Apr 24, 2022
2 parents bfd5466 + c56fffb commit 00edda1
Showing 149 changed files with 4,796 additions and 1,241 deletions.
6 changes: 5 additions & 1 deletion cmake/cuda.cmake
@@ -132,7 +132,11 @@ function(select_nvcc_arch_flags out_variable)
   elseif(${CUDA_ARCH_NAME} STREQUAL "Turing")
     set(cuda_arch_bin "75")
   elseif(${CUDA_ARCH_NAME} STREQUAL "Ampere")
-    set(cuda_arch_bin "80")
+    if (${CMAKE_CUDA_COMPILER_VERSION} LESS 11.1) # CUDA 11.0
+      set(cuda_arch_bin "80")
+    elseif (${CMAKE_CUDA_COMPILER_VERSION} LESS 12.0) # CUDA 11.1+
+      set(cuda_arch_bin "80 86")
+    endif()
   elseif(${CUDA_ARCH_NAME} STREQUAL "All")
     set(cuda_arch_bin ${paddle_known_gpu_archs})
   elseif(${CUDA_ARCH_NAME} STREQUAL "Auto")
2 changes: 1 addition & 1 deletion cmake/external/cinn.cmake
@@ -26,7 +26,7 @@ add_definitions(-w)
 ######################################
 include(ExternalProject)
 set(CINN_PREFIX_DIR ${THIRD_PARTY_PATH}/CINN)
-set(CINN_GIT_TAG eedb801ca39bfc6b9621bc76c24a0bf98cb8404b)
+set(CINN_GIT_TAG release/v0.2)
 set(CINN_OPTIONAL_ARGS -DPY_VERSION=${PY_VERSION}
                        -DWITH_CUDA=${WITH_GPU}
                        -DWITH_CUDNN=${WITH_GPU}
6 changes: 2 additions & 4 deletions cmake/third_party.cmake
@@ -357,10 +357,8 @@ if (WITH_PSCORE)
   include(external/libmct)     # download, build, install libmct
   list(APPEND third_party_deps extern_libmct)

-  if (WITH_HETERPS)
-    include(external/rocksdb)     # download, build, install libmct
-    list(APPEND third_party_deps extern_rocksdb)
-  endif()
+  include(external/rocksdb)     # download, build, install libmct
+  list(APPEND third_party_deps extern_rocksdb)
 endif()

 if(WITH_XBYAK)
70 changes: 70 additions & 0 deletions paddle/fluid/distributed/common/topk_calculator.h
@@ -0,0 +1,70 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <queue>
#include <unordered_map>

namespace paddle {
namespace distributed {
class TopkCalculator {
 public:
  TopkCalculator(int shard_num, size_t k)
      : _shard_num(shard_num), _total_max_size(k) {
    _shard_max_size = _total_max_size / shard_num;
    _shard_max_size = _shard_max_size > 1 ? _shard_max_size : 1;
    for (int i = 0; i < shard_num; ++i) {
      _mpq.emplace(i, std::priority_queue<double, std::vector<double>,
                                          std::greater<double>>());
    }
  }
  ~TopkCalculator() {}
  bool push(int shard_id, double value) {
    if (_mpq.find(shard_id) == _mpq.end()) {
      return false;
    }
    auto &pq = _mpq[shard_id];
    if (pq.size() < _shard_max_size) {
      pq.push(value);
    } else {
      if (pq.top() < value) {
        pq.pop();
        pq.push(value);
      }
    }
    return true;
  }
  // TODO: run one more heap sort to merge the per-shard results.
  int top() {
    double total = 0;
    for (const auto &item : _mpq) {
      auto &pq = item.second;
      if (!pq.empty()) {
        total += pq.top();
      }
    }
    return total / _shard_num;
  }

 private:
  std::unordered_map<int, std::priority_queue<double, std::vector<double>,
                                              std::greater<double>>>
      _mpq;
  int _shard_num;
  size_t _total_max_size;
  size_t _shard_max_size;
};

} // namespace distributed
} // namespace paddle
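
For orientation, here is a minimal, hypothetical usage sketch of the new TopkCalculator (the driver below is not part of this commit; the shard count, k, and the pushed values are made up). Each shard keeps a small min-heap of its largest pushed values, and top() returns the per-shard heap tops averaged and truncated to an int.

// Hypothetical driver, not part of the commit: exercise TopkCalculator.
#include <cstdio>

#include "paddle/fluid/distributed/common/topk_calculator.h"

int main() {
  // 4 shards, keep roughly the top 8 values in total (2 per shard).
  paddle::distributed::TopkCalculator topk(4, 8);
  for (int shard = 0; shard < 4; ++shard) {
    for (int i = 0; i < 100; ++i) {
      // push() keeps only the largest _shard_max_size values per shard.
      topk.push(shard, shard * 100.0 + i);
    }
  }
  // top() averages each shard's smallest retained value (note: returns int).
  std::printf("estimated threshold: %d\n", topk.top());
  return 0;
}
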
6 changes: 5 additions & 1 deletion paddle/fluid/distributed/ps/service/CMakeLists.txt
@@ -1,7 +1,11 @@
 set(BRPC_SRCS ps_client.cc server.cc)
 set_source_files_properties(${BRPC_SRCS})

-set(BRPC_DEPS brpc ssl crypto protobuf gflags glog zlib leveldb snappy gflags glog device_context)
+if(WITH_HETERPS)
+  set(BRPC_DEPS brpc ssl crypto protobuf gflags glog zlib leveldb snappy gflags glog device_context rocksdb)
+else()
+  set(BRPC_DEPS brpc ssl crypto protobuf gflags glog zlib leveldb snappy gflags glog device_context)
+endif()

 brpc_library(sendrecv_rpc SRCS
     ${BRPC_SRCS}
76 changes: 76 additions & 0 deletions paddle/fluid/distributed/ps/service/brpc_ps_client.cc
@@ -429,6 +429,82 @@ std::future<int32_t> BrpcPsClient::Save(uint32_t table_id,
  return SendSaveCmd(table_id, PS_SAVE_ONE_TABLE, {epoch, mode});
}

std::future<int32_t> BrpcPsClient::CacheShuffle(
    uint32_t table_id, const std::string &path, const std::string &mode,
    const std::string &cache_threshold) {
  VLOG(1) << "BrpcPsClient send cmd for cache shuffle";
  return SendSaveCmd(table_id, PS_CACHE_SHUFFLE, {path, mode, cache_threshold});
}

std::future<int32_t> BrpcPsClient::CacheShuffleMultiTable(
    std::vector<int> tables, const std::string &path, const std::string &mode,
    const std::string &cache_threshold) {
  VLOG(1) << "BrpcPsClient send cmd for cache shuffle multi table one path";
  std::vector<std::string> param;
  param.push_back(path);
  param.push_back(mode);
  param.push_back(cache_threshold);
  for (size_t i = 0; i < tables.size(); i++) {
    param.push_back(std::to_string(tables[i]));
  }
  return SendSaveCmd(0, PS_CACHE_SHUFFLE, param);
}

std::future<int32_t> BrpcPsClient::SaveCache(uint32_t table_id,
                                             const std::string &path,
                                             const std::string &mode) {
  return SendSaveCmd(table_id, PS_SAVE_ONE_CACHE_TABLE, {path, mode});
}

std::future<int32_t> BrpcPsClient::GetCacheThreshold(uint32_t table_id,
                                                     double &cache_threshold) {
  int cmd_id = PS_GET_CACHE_THRESHOLD;
  size_t request_call_num = _server_channels.size();
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num,
      [request_call_num, cmd_id, &cache_threshold](void *done) {
        int ret = 0;
        auto *closure = (DownpourBrpcClosure *)done;
        std::vector<double> cache_thresholds(request_call_num, 0);
        for (size_t i = 0; i < request_call_num; ++i) {
          if (closure->check_response(i, cmd_id) != 0) {
            ret = -1;
            break;
          }
          std::string cur_res = closure->get_response(i, cmd_id);
          cache_thresholds[i] = std::stod(cur_res);
        }
        double sum_threshold = 0.0;
        int count = 0;
        for (auto t : cache_thresholds) {
          if (t >= 0) {
            sum_threshold += t;
            ++count;
          }
        }
        if (count == 0) {
          cache_threshold = 0;
        } else {
          cache_threshold = sum_threshold / count;
        }
        VLOG(1) << "client get cache threshold: " << cache_threshold;
        closure->set_promise_value(ret);
      });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  for (size_t i = 0; i < request_call_num; ++i) {
    closure->request(i)->set_cmd_id(cmd_id);
    closure->request(i)->set_table_id(table_id);
    closure->request(i)->set_client_id(_client_id);
    PsService_Stub rpc_stub(GetCmdChannel(i));
    closure->cntl(i)->set_timeout_ms(10800000);
    rpc_stub.service(closure->cntl(i), closure->request(i),
                     closure->response(i), closure);
  }
  return fut;
}

std::future<int32_t> BrpcPsClient::Clear() {
  return SendCmd(-1, PS_CLEAR_ALL_TABLE, {});
}
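
For context, a hypothetical caller-side sketch of the new cache APIs (the function, variable names, mode value, and call order are assumptions; only the BrpcPsClient method signatures come from the diff above):

// Hypothetical caller-side sketch; not part of the commit.
#include <future>
#include <string>

#include "paddle/fluid/distributed/ps/service/brpc_ps_client.h"

void SaveCacheTable(paddle::distributed::BrpcPsClient *client,
                    uint32_t table_id, const std::string &path) {
  // Ask the servers for an averaged cache threshold (filled by reference).
  double threshold = 0.0;
  if (client->GetCacheThreshold(table_id, threshold).get() != 0) {
    return;  // RPC failed on at least one server
  }
  // Shuffle cache entries using the threshold, then persist the cache table.
  const std::string mode = "0";  // assumed save mode
  client->CacheShuffle(table_id, path, mode, std::to_string(threshold)).get();
  client->SaveCache(table_id, path, mode).get();
}
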
14 changes: 14 additions & 0 deletions paddle/fluid/distributed/ps/service/brpc_ps_client.h
@@ -219,6 +219,20 @@ class BrpcPsClient : public PSClient {
  virtual int32_t RecvAndSaveTable(const uint64_t table_id,
                                   const std::string &path);

  std::future<int32_t> CacheShuffle(
      uint32_t table_id, const std::string &path, const std::string &mode,
      const std::string &cache_threshold) override;

  std::future<int32_t> CacheShuffleMultiTable(
      std::vector<int> tables, const std::string &path, const std::string &mode,
      const std::string &cache_threshold);

  std::future<int32_t> SaveCache(uint32_t table_id, const std::string &path,
                                 const std::string &mode) override;

  std::future<int32_t> GetCacheThreshold(uint32_t table_id,
                                         double &cache_threshold) override;

  void PrintQueueSize();
  void PrintQueueSizeThread();

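
Similarly, a hypothetical sketch of the multi-table variant declared above (the table ids, path, and mode value are made up; only the CacheShuffleMultiTable signature comes from this header):

// Hypothetical sketch; not part of the commit.
#include <string>
#include <vector>

#include "paddle/fluid/distributed/ps/service/brpc_ps_client.h"

void ShuffleAllCaches(paddle::distributed::BrpcPsClient *client,
                      const std::string &path, const std::string &threshold) {
  // Shuffle the caches of several tables into one output path in one command.
  std::vector<int> tables = {0, 1, 2};  // assumed table ids
  client->CacheShuffleMultiTable(tables, path, /*mode=*/"0", threshold).get();
}
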

0 comments on commit 00edda1
