-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add ParameterServerController for parameter server python api #1051
Changes from 6 commits
f3c61cb
f9a65b0
95f20b9
cfbb4c4
7783982
93e74f8
3f6c2b3
5aaaef4
b1eeb2e
d32c7a6
aa9f516
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include "PServerController.h" | ||
|
||
namespace paddle { | ||
|
||
PServerController::PServerController(const ParameterServerConfig& config) { | ||
// round robin to load balance RDMA server ENGINE | ||
std::vector<std::string> devices; | ||
int rdmaCpu = 0; | ||
int onlineCpus = rdma::numCpus(); | ||
int numPorts = config.ports_num() + config.ports_num_for_sparse(); | ||
|
||
if (config.nics().empty()) { | ||
pservers_.resize(numPorts); | ||
for (int i = 0; i < numPorts; ++i) { | ||
if (config.rdma_tcp() == "rdma") { | ||
pservers_[i].reset( | ||
new ParameterServer2(std::string(), config.port() + i, rdmaCpu++)); | ||
rdmaCpu = rdmaCpu % onlineCpus; | ||
} else { | ||
pservers_[i].reset( | ||
new ParameterServer2(std::string(), config.port() + i)); | ||
} | ||
CHECK(pservers_[i]->init()) << "Fail to initialize parameter server" | ||
<< config.port() + i; | ||
} | ||
} else { | ||
str::split(config.nics(), ',', &devices); | ||
pservers_.resize(devices.size() * numPorts); | ||
for (int i = 0; i < numPorts; ++i) { | ||
for (size_t j = 0; j < devices.size(); ++j) { | ||
if (config.rdma_tcp() == "rdma") { | ||
pservers_[i * devices.size() + j].reset(new ParameterServer2( | ||
getIpAddr(devices[j]), config.port() + i, rdmaCpu++)); | ||
rdmaCpu = rdmaCpu % onlineCpus; | ||
} else { | ||
pservers_[i * devices.size() + j].reset( | ||
new ParameterServer2(getIpAddr(devices[j]), config.port() + i)); | ||
} | ||
CHECK(pservers_[i * devices.size() + j]->init()) | ||
<< "Fail to initialize parameter server" << devices[j] | ||
<< config.port() + i; | ||
} | ||
} | ||
} | ||
} | ||
|
||
PServerController::~PServerController() { this->join(); } | ||
|
||
ParameterServerConfig* PServerController::initConfigByGflags() { | ||
ParameterServerConfig* config = new ParameterServerConfig(); | ||
config->set_nics(FLAGS_nics); | ||
config->set_port(FLAGS_port); | ||
config->set_ports_num(FLAGS_ports_num); | ||
config->set_rdma_tcp(FLAGS_rdma_tcp); | ||
return config; | ||
} | ||
|
||
PServerController* PServerController::createByGflags() { | ||
auto& pServerConfig = *paddle::PServerController::initConfigByGflags(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里有内存泄露 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里直接用栈变量不可以么?也就是
这样。同时,initConfigByGflags 只被 createByGFlags 调用,没必要extract成一个private的static member function了吧。 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
return create(pServerConfig); | ||
} | ||
|
||
PServerController* PServerController::create( | ||
const ParameterServerConfig& config) { | ||
return new PServerController(config); | ||
} | ||
|
||
void PServerController::start() { | ||
LOG(INFO) << "pserver sizes : " << pservers_.size(); | ||
int i = 0; | ||
for (const auto& pserver : pservers_) { | ||
LOG(INFO) << "pserver started : " << i; | ||
pserver->start(); | ||
i++; | ||
} | ||
} | ||
|
||
void PServerController::join() { | ||
LOG(INFO) << "pserver sizes : " << pservers_.size(); | ||
int i = 0; | ||
for (const auto& pserver : pservers_) { | ||
LOG(INFO) << "pserver join : " << i; | ||
pserver->join(); | ||
i++; | ||
} | ||
} | ||
|
||
} // namespace paddle |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#pragma once | ||
|
||
#include "ParameterServer2.h" | ||
#include "ParameterServerConfig.pb.h" | ||
#include "RDMANetwork.h" | ||
#include "paddle/utils/StringUtil.h" | ||
|
||
namespace paddle { | ||
|
||
class PServerController { | ||
public: | ||
DISABLE_COPY(PServerController); | ||
|
||
/** | ||
* @brief Ctor, Create a PServerUtil from ParameterServerConfig. | ||
*/ | ||
explicit PServerController(const ParameterServerConfig& config); | ||
|
||
/** | ||
* @brief Dtor. | ||
*/ | ||
~PServerController(); | ||
|
||
/** | ||
* @brief create PServerUtil from gflags, this is used for | ||
* compatibility with the old usage of configuration by gflags. | ||
*/ | ||
static PServerController* createByGflags(); | ||
|
||
/** | ||
* @brief create PServerUtil with ParameterServerConfig, remove gflags | ||
* from ParameterServer. Init all pservers thread according to the config. | ||
*/ | ||
static PServerController* create(const ParameterServerConfig& config); | ||
|
||
/** | ||
* @brief start all pserver thread in this PServerUtil. | ||
*/ | ||
void start(); | ||
|
||
/** | ||
* @brief join and wait for all pserver thread in this PServerUtil. | ||
*/ | ||
void join(); | ||
|
||
private: | ||
std::vector<std::shared_ptr<ParameterServer2>> pservers_; | ||
|
||
/** | ||
* @brief create ParameterServerConfig from gflags, this is used for | ||
* compatibility with the old usage of configuration by gflags. | ||
*/ | ||
static ParameterServerConfig* initConfigByGflags(); | ||
}; | ||
|
||
} // namespace paddle |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
尽量不要用C++的new关键词。
最省事的方法是:
auto config = std::make_shared<ParameterServerConfig>();
或者是
虽然目前没有make_unique,不过回头我加上吧。。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
对了,勉强在Cpp里面和java
new
语意一致的东西是std::make_shared<类型名>(参数)
。只是std::make_unique
会快一点。There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
赞,多谢,已经修改了,不过make_unique打算如何引入?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
直接加上这个函数也可以。。http://stackoverflow.com/questions/17902405/how-to-implement-make-unique-function-in-c11
判断一下C++版本, if __cplusplus != 14,那么就加上make_unique。