From 55fb0fcdee6c7d377b96a55271a565063d825a77 Mon Sep 17 00:00:00 2001 From: tqchen Date: Sat, 2 May 2015 22:01:06 -0700 Subject: [PATCH] fix --- src/cxxnet_main.cpp | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/cxxnet_main.cpp b/src/cxxnet_main.cpp index 82c438d0..275a3b21 100644 --- a/src/cxxnet_main.cpp +++ b/src/cxxnet_main.cpp @@ -90,11 +90,6 @@ class CXXNetLearnTask { os << rabit::GetWorldSize(); this->SetParam("dist_num_worker", os.str().c_str()); } - if (device == "gpu:rank") { - std::ostringstream os; - os << "gpu:" << rabit::GetRank(); - device = os.str(); - } #endif dmlc::Stream *cfg = dmlc::Stream::Create(argv[1], "r"); { @@ -231,7 +226,16 @@ class CXXNetLearnTask { if (reset_net_type != -1) { net_type = reset_net_type; } + int rank = 0; +#if MSHADOW_RABIT_PS + rank = rabit::GetRank(); +#endif nnet::INetTrainer *net; + if (device == "gpu:rank") { + std::ostringstream os; + os << "gpu:" << rank; + device = os.str(); + } if (!strncmp(device.c_str(), "gpu", 3)) { #if MSHADOW_USE_CUDA net = nnet::CreateNet(net_type);