Skip to content

Commit

Permalink
[PsCore] support ssd (#33031)
Browse files Browse the repository at this point in the history
* support ssd in PsCore

* remove log

* remove bz2

* default value

* code style

* parse table class

* code style

* add define
  • Loading branch information
Thunderbrook authored May 27, 2021
1 parent b425215 commit 988b5fe
Show file tree
Hide file tree
Showing 22 changed files with 914 additions and 108 deletions.
51 changes: 51 additions & 0 deletions cmake/external/rocksdb.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Downloads and builds RocksDB as an external project and exposes the result
# as the imported static library target `rocksdb`.
INCLUDE(ExternalProject)

# Source/build tree and install tree, both rooted at THIRD_PARTY_PATH.
SET(ROCKSDB_SOURCES_DIR ${THIRD_PARTY_PATH}/rocksdb)
SET(ROCKSDB_INSTALL_DIR ${THIRD_PARTY_PATH}/install/rocksdb)
# FORCE overwrites any value already stored in the CMake cache.
SET(ROCKSDB_INCLUDE_DIR "${ROCKSDB_INSTALL_DIR}/include" CACHE PATH "rocksdb include directory." FORCE)
SET(ROCKSDB_LIBRARIES "${ROCKSDB_INSTALL_DIR}/lib/librocksdb.a" CACHE FILEPATH "rocksdb library." FORCE)
# Reuse the project's C++ flags and add -fPIC for the RocksDB build.
SET(ROCKSDB_CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC")
INCLUDE_DIRECTORIES(${ROCKSDB_INCLUDE_DIR})

# Fetch the pinned tag v6.10.1 and build it in the source tree
# (BUILD_IN_SOURCE 1) with bzip2 and gflags support switched off.
# UPDATE_COMMAND is empty, so no update step is run.
ExternalProject_Add(
extern_rocksdb
${EXTERNAL_PROJECT_LOG_ARGS}
PREFIX ${ROCKSDB_SOURCES_DIR}
GIT_REPOSITORY "https://github.com/facebook/rocksdb"
GIT_TAG v6.10.1
UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
-DWITH_BZ2=OFF
-DWITH_GFLAGS=OFF
-DCMAKE_CXX_FLAGS=${ROCKSDB_CMAKE_CXX_FLAGS}
-DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}
# NOTE(review): BUILD_BYPRODUCTS is disabled. Generators such as Ninja need the
# produced library declared as a byproduct to resolve the imported target's
# file; if it is re-enabled it should probably name ${ROCKSDB_LIBRARIES}
# (the IMPORTED_LOCATION below), not the in-source path - confirm.
# BUILD_BYPRODUCTS ${ROCKSDB_SOURCES_DIR}/src/extern_rocksdb/librocksdb.a
# Manual install step: copy the static library and the public headers from the
# in-source build into ROCKSDB_INSTALL_DIR.
INSTALL_COMMAND mkdir -p ${ROCKSDB_INSTALL_DIR}/lib/
&& cp ${ROCKSDB_SOURCES_DIR}/src/extern_rocksdb/librocksdb.a ${ROCKSDB_LIBRARIES}
&& cp -r ${ROCKSDB_SOURCES_DIR}/src/extern_rocksdb/include ${ROCKSDB_INSTALL_DIR}/
BUILD_IN_SOURCE 1
)

# Make the RocksDB build wait for the `snappy` target.
# NOTE(review): no -DWITH_SNAPPY argument is passed above; confirm whether
# RocksDB actually picks snappy up or this is only an ordering constraint.
ADD_DEPENDENCIES(extern_rocksdb snappy)

# Imported target that points at the installed static library and depends on
# the external build having run.
ADD_LIBRARY(rocksdb STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET rocksdb PROPERTY IMPORTED_LOCATION ${ROCKSDB_LIBRARIES})
ADD_DEPENDENCIES(rocksdb extern_rocksdb)

LIST(APPEND external_project_dependencies rocksdb)

5 changes: 5 additions & 0 deletions cmake/third_party.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,11 @@ if (WITH_PSCORE)

include(external/libmct) # download, build, install libmct
list(APPEND third_party_deps extern_libmct)

# rocksdb is only pulled in when WITH_HETERPS is enabled.
if (WITH_HETERPS)
include(external/rocksdb) # download, build, install rocksdb
list(APPEND third_party_deps extern_rocksdb)
endif()
endif()

if(WITH_XBYAK)
Expand Down
11 changes: 8 additions & 3 deletions paddle/fluid/distributed/fleet.cc
Original file line number Diff line number Diff line change
Expand Up @@ -417,8 +417,10 @@ void FleetWrapper::PushSparseFromTensorWithLabelAsync(
return;
}

void FleetWrapper::LoadModel(const std::string& path, const int mode) {
auto ret = pserver_ptr_->_worker_ptr->load(path, std::to_string(mode));
void FleetWrapper::LoadModel(const std::string& path, const std::string& mode) {
auto* communicator = Communicator::GetInstance();
auto ret = communicator->_worker_ptr->load(path, mode);
// auto ret = pserver_ptr_->_worker_ptr->load(path, std::to_string(mode));
ret.wait();
if (ret.get() != 0) {
LOG(ERROR) << "load model from path:" << path << " failed";
Expand All @@ -429,8 +431,11 @@ void FleetWrapper::LoadModel(const std::string& path, const int mode) {

void FleetWrapper::LoadModelOneTable(const uint64_t table_id,
const std::string& path, const int mode) {
auto* communicator = Communicator::GetInstance();
auto ret =
pserver_ptr_->_worker_ptr->load(table_id, path, std::to_string(mode));
communicator->_worker_ptr->load(table_id, path, std::to_string(mode));
// auto ret =
// pserver_ptr_->_worker_ptr->load(table_id, path, std::to_string(mode));
ret.wait();
if (ret.get() != 0) {
LOG(ERROR) << "load model of table id: " << table_id
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/distributed/fleet.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ class FleetWrapper {
void PrintTableStat(const uint64_t table_id);
// mode = 0, load all feature
// mode = 1, load delta feature, which means load diff
void LoadModel(const std::string& path, const int mode);
void LoadModel(const std::string& path, const std::string& mode);
// mode = 0, load all feature
// mode = 1, load delta feature, which means load diff
void LoadModelOneTable(const uint64_t table_id, const std::string& path,
Expand Down
11 changes: 5 additions & 6 deletions paddle/fluid/distributed/service/ps_local_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,17 +42,17 @@ ::std::future<int32_t> PsLocalClient::shrink(uint32_t table_id,
// Loads every table registered in _table_map by delegating to the per-table
// overload with the same epoch and mode, then returns done().
// NOTE(review): the futures returned by the per-table calls are discarded, so
// the result does not reflect individual table failures - confirm intended.
::std::future<int32_t> PsLocalClient::load(const std::string& epoch,
                                           const std::string& mode) {
  for (const auto& entry : _table_map) {
    load(entry.first, epoch, mode);
  }
  return done();
}
// Loads a single table: looks the table up by id, forwards epoch and mode to
// its load(), and returns done().
// NOTE(review): the lookup result is not null-checked and the value returned
// by the table's load() is ignored - confirm both are acceptable here.
::std::future<int32_t> PsLocalClient::load(uint32_t table_id,
                                           const std::string& epoch,
                                           const std::string& mode) {
  table(table_id)->load(epoch, mode);
  return done();
}

Expand Down Expand Up @@ -245,7 +245,6 @@ ::std::future<int32_t> PsLocalClient::pull_sparse_ptr(char** select_values,
::std::future<int32_t> PsLocalClient::push_sparse_raw_gradient(
size_t table_id, const uint64_t* keys, const float** update_values,
size_t num, void* callback) {
VLOG(1) << "wxx push_sparse_raw_gradient";
PSClientClosure* closure = reinterpret_cast<PSClientClosure*>(callback);
auto* accessor = table_accessor(table_id);
auto* table_ptr = table(table_id);
Expand Down
7 changes: 6 additions & 1 deletion paddle/fluid/distributed/service/ps_local_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,14 @@ class PsLocalServer : public PSServer {
PsLocalServer() {}
virtual ~PsLocalServer() {}
virtual uint64_t start() { return 0; }
virtual uint64_t start(const std::string& ip, uint32_t port) { return 0; }
virtual uint64_t start(const std::string &ip, uint32_t port) { return 0; }
virtual int32_t stop() { return 0; }
virtual int32_t port() { return 0; }
virtual int32_t configure(
const PSParameter &config, PSEnvironment &env, size_t server_rank,
const std::vector<framework::ProgramDesc> &server_sub_program = {}) {
return 0;
}

private:
virtual int32_t initialize() { return 0; }
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/distributed/service/server.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class PSServer {

virtual int32_t configure(
const PSParameter &config, PSEnvironment &env, size_t server_rank,
const std::vector<framework::ProgramDesc> &server_sub_program = {}) final;
const std::vector<framework::ProgramDesc> &server_sub_program = {});

// return server_ip
virtual std::string ip() { return butil::my_ip_cstr(); }
Expand Down
15 changes: 12 additions & 3 deletions paddle/fluid/distributed/table/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,24 @@ set_source_files_properties(${graphDir}/graph_node.cc PROPERTIES COMPILE_FLAGS $
cc_library(graph_node SRCS ${graphDir}/graph_node.cc DEPS WeightedSampler)
set_source_files_properties(common_dense_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(common_sparse_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(ssd_sparse_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(sparse_geo_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(barrier_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(common_graph_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})

get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS)

cc_library(common_table SRCS common_sparse_table.cc common_dense_table.cc
sparse_geo_table.cc barrier_table.cc common_graph_table.cc DEPS ${TABLE_DEPS}
${RPC_DEPS} graph_edge graph_node device_context string_helper simple_threadpool xxhash generator)
# Sources that always go into common_table. The list is kept in one place so a
# new table source only has to be added once.
set(TABLE_SRC common_sparse_table.cc common_dense_table.cc sparse_geo_table.cc barrier_table.cc common_graph_table.cc)
# Extra dependency for common_table; stays empty unless WITH_HETERPS is on.
set(EXTERN_DEP "")
if(WITH_HETERPS)
  # ssd_sparse_table.cc is only compiled for WITH_HETERPS builds, together
  # with its rocksdb dependency.
  list(APPEND TABLE_SRC ssd_sparse_table.cc)
  set(EXTERN_DEP rocksdb)
endif()

cc_library(common_table SRCS ${TABLE_SRC} DEPS ${TABLE_DEPS}
${RPC_DEPS} graph_edge graph_node device_context string_helper
simple_threadpool xxhash generator ${EXTERN_DEP})

set_source_files_properties(tensor_accessor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(tensor_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
Expand Down
104 changes: 19 additions & 85 deletions paddle/fluid/distributed/table/common_sparse_table.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,83 +25,12 @@ class ValueBlock;
} // namespace distributed
} // namespace paddle

#define PSERVER_SAVE_SUFFIX ".shard"
using boost::lexical_cast;

namespace paddle {
namespace distributed {

// Named constants for the table save mode.
// NOTE(review): the meaning of each mode is not visible in this hunk -
// presumably `all` writes every value and `delta` only changed ones; confirm
// against the save implementation.
enum SaveMode { all, base, delta };

// Metadata for one saved shard of a parameter. It is either parsed from a
// "key=value" text file or built from explicit values, and ToString() renders
// it back into the same "key=value" form.
struct Meta {
  std::string param;
  // Zero-initialised so a meta file that omits the key yields a defined value.
  int shard_id = 0;
  std::vector<std::string> names;
  std::vector<int> dims;
  // Zero-initialised for the same reason as shard_id.
  uint64_t count = 0;
  // Row name -> row dimension. Only filled by the file-parsing constructor.
  std::unordered_map<std::string, int> dims_map;

  // Parses `metapath` line by line. Lines starting with '#' are skipped; every
  // other line must split into exactly two parts around '=' or an
  // InvalidArgument error is raised. Recognised keys: param, shard_id,
  // row_names, row_dims, count. Unrecognised keys are ignored.
  // NOTE(review): a file that cannot be opened yields an empty Meta without
  // any error; callers may want to check for that.
  explicit Meta(const std::string& metapath) {
    std::ifstream file(metapath);
    std::string line;
    while (std::getline(file, line)) {
      if (StartWith(line, "#")) {
        continue;
      }
      auto pairs = paddle::string::split_string<std::string>(line, "=");
      PADDLE_ENFORCE_EQ(
          pairs.size(), 2,
          paddle::platform::errors::InvalidArgument(
              "info in %s expect k=v, but got %s", metapath, line));

      const std::string& key = pairs[0];
      const std::string& value = pairs[1];
      if (key == "param") {
        param = value;
      } else if (key == "shard_id") {
        shard_id = std::stoi(value);
      } else if (key == "row_names") {
        names = paddle::string::split_string<std::string>(value, ",");
      } else if (key == "row_dims") {
        auto dims_strs = paddle::string::split_string<std::string>(value, ",");
        for (auto& str : dims_strs) {
          dims.push_back(std::stoi(str));
        }
      } else if (key == "count") {
        count = std::stoull(value);
      }
    }

    // Each row name needs a matching dimension; without this check the loop
    // below would read past the end of `dims`.
    PADDLE_ENFORCE_EQ(
        names.size() <= dims.size(), true,
        paddle::platform::errors::InvalidArgument(
            "info in %s has %d row_names but only %d row_dims", metapath,
            names.size(), dims.size()));
    for (size_t x = 0; x < names.size(); ++x) {
      dims_map[names[x]] = dims[x];
    }
  }

  // Builds a Meta from explicit values. dims_map is left empty.
  Meta(std::string param, int shard_id, std::vector<std::string> row_names,
       std::vector<int> dims, uint64_t count)
      : param(param),
        shard_id(shard_id),
        names(row_names),
        dims(dims),
        count(count) {}

  // Renders the metadata as newline-terminated "key=value" lines in the order
  // param, shard_id, row_names, row_dims, count.
  std::string ToString() const {
    std::stringstream ss;
    ss << "param=" << param << "\n";
    ss << "shard_id=" << shard_id << "\n";
    ss << "row_names=" << paddle::string::join_strings(names, ',') << "\n";
    ss << "row_dims=" << paddle::string::join_strings(dims, ',') << "\n";
    ss << "count=" << count << "\n";
    return ss.str();
  }
};

void ProcessALine(const std::vector<std::string>& columns, const Meta& meta,
const int64_t id, std::vector<std::vector<float>>* values) {
void CommonSparseTable::ProcessALine(const std::vector<std::string>& columns,
const Meta& meta, const int64_t id,
std::vector<std::vector<float>>* values) {
auto colunmn_size = columns.size();
auto load_values =
paddle::string::split_string<std::string>(columns[colunmn_size - 1], ",");
Expand Down Expand Up @@ -134,8 +63,10 @@ void ProcessALine(const std::vector<std::string>& columns, const Meta& meta,
}
}

void SaveMetaToText(std::ostream* os, const CommonAccessorParameter& common,
const size_t shard_idx, const int64_t total) {
void CommonSparseTable::SaveMetaToText(std::ostream* os,
const CommonAccessorParameter& common,
const size_t shard_idx,
const int64_t total) {
// save meta
std::stringstream stream;
stream << "param=" << common.table_name() << "\n";
Expand All @@ -148,8 +79,10 @@ void SaveMetaToText(std::ostream* os, const CommonAccessorParameter& common,
os->write(stream.str().c_str(), sizeof(char) * stream.str().size());
}

int64_t SaveValueToText(std::ostream* os, std::shared_ptr<ValueBlock> block,
std::shared_ptr<::ThreadPool> pool, const int mode) {
int64_t CommonSparseTable::SaveValueToText(std::ostream* os,
std::shared_ptr<ValueBlock> block,
std::shared_ptr<::ThreadPool> pool,
const int mode, int shard_id) {
int64_t save_num = 0;
for (auto& table : block->values_) {
for (auto& value : table) {
Expand Down Expand Up @@ -186,10 +119,10 @@ int64_t SaveValueToText(std::ostream* os, std::shared_ptr<ValueBlock> block,
return save_num;
}

int64_t LoadFromText(const std::string& valuepath, const std::string& metapath,
const int pserver_id, const int pserver_num,
const int local_shard_num,
std::vector<std::shared_ptr<ValueBlock>>* blocks) {
int64_t CommonSparseTable::LoadFromText(
const std::string& valuepath, const std::string& metapath,
const int pserver_id, const int pserver_num, const int local_shard_num,
std::vector<std::shared_ptr<ValueBlock>>* blocks) {
Meta meta = Meta(metapath);

int num_lines = 0;
Expand All @@ -198,7 +131,7 @@ int64_t LoadFromText(const std::string& valuepath, const std::string& metapath,

while (std::getline(file, line)) {
auto values = paddle::string::split_string<std::string>(line, "\t");
auto id = lexical_cast<int64_t>(values[0]);
auto id = lexical_cast<uint64_t>(values[0]);

if (id % pserver_num != pserver_id) {
VLOG(3) << "will not load " << values[0] << " from " << valuepath
Expand Down Expand Up @@ -388,8 +321,9 @@ int32_t CommonSparseTable::save(const std::string& dirname,
int64_t total_ins = 0;
for (int shard_id = 0; shard_id < task_pool_size_; ++shard_id) {
// save values
auto shard_save_num = SaveValueToText(vs.get(), shard_values_[shard_id],
_shards_task_pool[shard_id], mode);
auto shard_save_num =
SaveValueToText(vs.get(), shard_values_[shard_id],
_shards_task_pool[shard_id], mode, shard_id);
total_ins += shard_save_num;
}
vs->close();
Expand Down
Loading

0 comments on commit 988b5fe

Please sign in to comment.