Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed May 3, 2015
1 parent 42d5c98 commit 55fb0fc
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions src/cxxnet_main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,6 @@ class CXXNetLearnTask {
os << rabit::GetWorldSize();
this->SetParam("dist_num_worker", os.str().c_str());
}
if (device == "gpu:rank") {
std::ostringstream os;
os << "gpu:" << rabit::GetRank();
device = os.str();
}
#endif
dmlc::Stream *cfg = dmlc::Stream::Create(argv[1], "r");
{
Expand Down Expand Up @@ -231,7 +226,16 @@ class CXXNetLearnTask {
if (reset_net_type != -1) {
net_type = reset_net_type;
}
int rank = 0;
#if MSHADOW_RABIT_PS
rank = rabit::GetRank();
#endif
nnet::INetTrainer *net;
if (device == "gpu:rank") {
std::ostringstream os;
os << "gpu:" << rank;
device = os.str();
}
if (!strncmp(device.c_str(), "gpu", 3)) {
#if MSHADOW_USE_CUDA
net = nnet::CreateNet<mshadow::gpu>(net_type);
Expand Down

0 comments on commit 55fb0fc

Please sign in to comment.