Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[rabit improvement] support rabit worker set/get configs from tracker #94

Closed
wants to merge 10 commits into from
6 changes: 3 additions & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
cmake_minimum_required(VERSION 3.0)
cmake_minimum_required(VERSION 3.3)

project(rabit VERSION 0.2.0)
project(rabit VERSION 0.2.1)

option(RABIT_BUILD_TESTS "Build rabit tests" OFF)
option(RABIT_BUILD_MPI "Build MPI" OFF)
option(RABIT_BUILD_DMLC "Include DMLC_CORE in build" ON)
option(RABIT_BUILD_DMLC "Include DMLC_CORE in build" OFF)

add_library(rabit src/allreduce_base.cc src/allreduce_robust.cc src/engine.cc src/c_api.cc)
add_library(rabit_base src/allreduce_base.cc src/engine_base.cc src/c_api.cc)
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ else
endif

export WARNFLAGS= -Wall -Wextra -Wno-unused-parameter -Wno-unknown-pragmas -std=c++11
export CFLAGS = -O3 $(WARNFLAGS) -I $(DMLC)/include -I include/
export CFLAGS = -O3 -g $(WARNFLAGS) -I $(DMLC)/include -I include/
export LDFLAGS =-Llib

#download mpi
Expand Down
6 changes: 4 additions & 2 deletions include/rabit/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ RABIT_DLL void RabitGetProcessorName(char *out_name,
* \param root the root of process
*/
RABIT_DLL void RabitBroadcast(void *sendrecv_data,
rbt_ulong size, int root);
rbt_ulong size, int root,
const char* caller = __builtin_FUNCTION());
/*!
* \brief perform in-place allreduce, on sendrecvbuf
* this function is NOT thread-safe
Expand All @@ -108,7 +109,8 @@ RABIT_DLL void RabitAllreduce(void *sendrecvbuf,
int enum_dtype,
int enum_op,
void (*prepare_fun)(void *arg),
void *prepare_arg);
void *prepare_arg,
const char* caller = __builtin_FUNCTION());

/*!
* \brief load latest check point
Expand Down
2 changes: 2 additions & 0 deletions include/rabit/internal/engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,8 @@ class IEngine {
* \param msg message to be printed in the tracker
*/
virtual void TrackerPrint(const std::string &msg) = 0;
/*!
 * \brief hand a key/value configuration pair to the engine for storage on the tracker
 * \param key configuration key
 * \param value configuration value to associate with the key
 */
virtual void TrackerSetConfig(const std::string &key, const std::string &value) = 0;
/*!
 * \brief ask the engine for the configuration value stored under a key
 * \param key configuration key
 * \param value output string for the result
 *        NOTE(review): whether *value is written depends on the concrete engine - confirm per implementation
 */
virtual void TrackerGetConfig(const std::string& key, std::string* value) = 0;
};

/*! \brief initializes the engine module */
Expand Down
41 changes: 41 additions & 0 deletions include/rabit/internal/rabit-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,15 @@ inline void Allreduce(DType *sendrecvbuf, size_t count, std::function<void()> pr
inline void TrackerPrint(const std::string &msg) {
engine::GetEngine()->TrackerPrint(msg);
}

inline void TrackerSetConfig(const std::string &key, const std::string &value) {
engine::GetEngine()->TrackerSetConfig(key, value);
}

inline void TrackerGetConfig(const std::string &key, std::string* value) {
engine::GetEngine()->TrackerGetConfig(key, value);
}

#ifndef RABIT_STRICT_CXX98_
inline void TrackerPrintf(const char *fmt, ...) {
const int kPrintBuffer = 1 << 10;
Expand All @@ -188,6 +197,38 @@ inline void TrackerPrintf(const char *fmt, ...) {
msg.resize(strlen(msg.c_str()));
TrackerPrint(msg);
}

/*!
 * \brief printf-style variant of TrackerSetConfig
 *  Both key and value are treated as format strings. Each of them is expanded
 *  against the same variadic arguments, starting from the first argument, and
 *  the two resulting strings are handed to the engine's TrackerSetConfig.
 *  Each expanded string is truncated to kPrintBuffer - 1 characters.
 * \param key format string producing the configuration key
 * \param value format string producing the configuration value
 */
inline void TrackerSetConfig(const char *key, const char *value, ...) {
  const int kPrintBuffer = 1 << 10;
  std::string k(kPrintBuffer, '\0'), v(kPrintBuffer, '\0');

  // va_start must name the LAST named parameter (value); naming key is
  // undefined behaviour. The second traversal of the same argument list is
  // obtained through va_copy.
  va_list key_args, value_args;
  va_start(key_args, value);
  va_copy(value_args, key_args);
  vsnprintf(&k[0], kPrintBuffer, key, key_args);
  vsnprintf(&v[0], kPrintBuffer, value, value_args);
  va_end(value_args);
  va_end(key_args);
  // shrink both buffers to the length of the formatted text
  k.resize(strlen(k.c_str()));
  v.resize(strlen(v.c_str()));
  engine::GetEngine()->TrackerSetConfig(k, v);
}

/*!
 * \brief printf-style variant of TrackerGetConfig
 *  Both key and value are treated as format strings. Each of them is expanded
 *  against the same variadic arguments, starting from the first argument.
 *  The expanded key and a pointer to the expanded value string are handed to
 *  the engine's TrackerGetConfig. Each expanded string is truncated to
 *  kPrintBuffer - 1 characters.
 *
 *  NOTE(review): the string filled in by the engine lives in a local buffer and
 *  is never copied back into the caller's value pointer, so callers cannot
 *  observe the fetched configuration through this overload. The size of the
 *  caller's buffer is unknown here, so no copy-back is attempted - confirm the
 *  intended contract (or prefer the std::string overload).
 * \param key format string producing the configuration key
 * \param value format string producing the initial content of the value buffer
 */
inline void TrackerGetConfig(const char *key, char* value, ...) {
  const int kPrintBuffer = 1 << 10;
  std::string k(kPrintBuffer, '\0'), v(kPrintBuffer, '\0');

  // va_start must name the LAST named parameter (value); naming key is
  // undefined behaviour. The second traversal of the same argument list is
  // obtained through va_copy.
  va_list key_args, value_args;
  va_start(key_args, value);
  va_copy(value_args, key_args);
  vsnprintf(&k[0], kPrintBuffer, key, key_args);
  vsnprintf(&v[0], kPrintBuffer, value, value_args);
  va_end(value_args);
  va_end(key_args);
  // shrink both buffers to the length of the formatted text
  k.resize(strlen(k.c_str()));
  v.resize(strlen(v.c_str()));
  engine::GetEngine()->TrackerGetConfig(k, &v);
}
#endif // RABIT_STRICT_CXX98_
// load latest check point
inline int LoadCheckPoint(Serializable *global_model,
Expand Down
25 changes: 25 additions & 0 deletions include/rabit/rabit.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,19 @@ inline std::string GetProcessorName();
* \param msg the message to be printed
*/
inline void TrackerPrint(const std::string &msg);
/*!
 * \brief save a configuration entry to the tracker;
 *  the call is forwarded to the TrackerSetConfig method of the active engine
 * \param key configuration key
 * \param value value of the configuration entry
 */
inline void TrackerSetConfig(const std::string &key, const std::string &value);
/*!
 * \brief get a configuration entry from the tracker;
 *  the call is forwarded to the TrackerGetConfig method of the active engine
 * \param key configuration key
 * \param value output pointer handed to the engine to receive the value
 *        NOTE(review): whether *value is written depends on the engine in use - confirm
 */
inline void TrackerGetConfig(const std::string &key, std::string* value);

#ifndef RABIT_STRICT_CXX98_
/*!
* \brief prints the msg to the tracker, this function may not be available
Expand All @@ -108,6 +121,18 @@ inline void TrackerPrint(const std::string &msg);
* \param fmt the format string
*/
inline void TrackerPrintf(const char *fmt, ...);
/*!
 * \brief save a configuration entry to the tracker, printf-style;
 *  key and value are both used as format strings and are expanded
 *  with the trailing variadic arguments before being handed to the engine
 * \param key format string of the configuration key
 * \param value format string of the configuration value
 */
inline void TrackerSetConfig(const char *key, const char *value, ...);
/*!
 * \brief get a configuration entry from the tracker, printf-style;
 *  key and value are both used as format strings and are expanded
 *  with the trailing variadic arguments before being handed to the engine
 * \param key format string of the configuration key
 * \param value format string used to pre-fill the value buffer passed to the engine
 *        NOTE(review): the implementation passes a local buffer to the engine and
 *        does not write the fetched value back through this pointer - confirm intent
 */
inline void TrackerGetConfig(const char *key, char* value, ...);
#endif // RABIT_STRICT_CXX98_
/*!
* \brief broadcasts a memory region to every node from the root
Expand Down
16 changes: 8 additions & 8 deletions scripts/travis_runtest.sh
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
#!/bin/bash

make -f test.mk model_recover_10_10k || exit -1
make -f test.mk model_recover_10_10k_die_same || exit -1
make -f test.mk model_recover_10_10k_die_hard || exit -1
make -f test.mk local_recover_10_10k || exit -1
make -f test.mk lazy_recover_10_10k_die_hard || exit -1
make -f test.mk lazy_recover_10_10k_die_same || exit -1
make -f test.mk ringallreduce_10_10k || exit -1
make -f test.mk pylocal_recover_10_10k || exit -1
make -f test.mk RABIT_BUILD_DMLC=1 model_recover_10_10k || exit -1
make -f test.mk RABIT_BUILD_DMLC=1 model_recover_10_10k_die_same || exit -1
make -f test.mk RABIT_BUILD_DMLC=1 model_recover_10_10k_die_hard || exit -1
make -f test.mk RABIT_BUILD_DMLC=1 local_recover_10_10k || exit -1
make -f test.mk RABIT_BUILD_DMLC=1 lazy_recover_10_10k_die_hard || exit -1
make -f test.mk RABIT_BUILD_DMLC=1 lazy_recover_10_10k_die_same || exit -1
make -f test.mk RABIT_BUILD_DMLC=1 ringallreduce_10_10k || exit -1
make -f test.mk RABIT_BUILD_DMLC=1 pylocal_recover_10_10k || exit -1
18 changes: 18 additions & 0 deletions src/allreduce_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ void AllreduceBase::Shutdown(void) {
sock_listen.Close();
utils::TCPSocket::Finalize();
}

void AllreduceBase::TrackerPrint(const std::string &msg) {
if (tracker_uri == "NULL") {
utils::Printf("%s", msg.c_str()); return;
Expand All @@ -144,6 +145,23 @@ void AllreduceBase::TrackerPrint(const std::string &msg) {
tracker.SendStr(msg);
tracker.Close();
}

void AllreduceBase::TrackerSetConfig(const std::string &key, const std::string &value) {
utils::TCPSocket tracker = this->ConnectTracker();
tracker.SendStr(std::string("set"));
tracker.SendStr(key);
tracker.SendStr(value);
tracker.Close();
}

void AllreduceBase::TrackerGetConfig(const std::string &key, std::string* value) {
utils::TCPSocket tracker = this->ConnectTracker();
tracker.SendStr(std::string("get"));
tracker.SendStr(key);
tracker.RecvStr(value);
tracker.Close();
}

// util to parse data with unit suffix
inline size_t ParseUnit(const char *name, const char *val) {
char unit;
Expand Down
3 changes: 3 additions & 0 deletions src/allreduce_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ class AllreduceBase : public IEngine {
* \param msg message to be printed in the tracker
*/
virtual void TrackerPrint(const std::string &msg);
/*!
 * \brief send a key/value configuration pair to the tracker
 * \param key configuration key
 * \param value configuration value
 */
virtual void TrackerSetConfig(const std::string &key, const std::string &value);
/*!
 * \brief request the configuration value stored under a key from the tracker
 * \param key configuration key
 * \param value output string that receives the tracker's reply
 */
virtual void TrackerGetConfig(const std::string& key, std::string* value);

/*! \brief get rank */
virtual int GetRank(void) const {
return rank;
Expand Down
56 changes: 55 additions & 1 deletion src/allreduce_robust.cc
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,8 @@ int AllreduceRobust::LoadCheckPoint(Serializable *global_model,
// run another phase of check ack, if recovered from data
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kSpecialOp),
"check ack must return true");
utils::Printf("[%d] load checkpoint global %ld version %d\n", rank,
global_checkpoint.length(), version_number);
return version_number;
} else {
// reset result buffer
Expand Down Expand Up @@ -296,7 +298,6 @@ void AllreduceRobust::CheckPoint_(const Serializable *global_model,
if (lazy_checkpt) {
global_lazycheck = global_model;
} else {
printf("[%d] save global checkpoint #%d \n", this->rank, version_number);
global_checkpoint.resize(0);
utils::MemoryBufferStream fs(&global_checkpoint);
fs.Write(&version_number, sizeof(version_number));
Expand Down Expand Up @@ -547,6 +548,7 @@ AllreduceRobust::TryDecideRouting(AllreduceRobust::RecoverType role,
{
// get the shortest distance to the request point
std::vector<std::pair<int, size_t> > dist_in, dist_out;

ReturnType succ = MsgPassing(std::make_pair(role == kHaveData, *p_size),
&dist_in, &dist_out, ShortestDist);
if (succ != kSuccess) return succ;
Expand Down Expand Up @@ -713,6 +715,51 @@ AllreduceRobust::TryRecoverData(RecoverType role,
}
return kSuccess;
}
/*!
 * \brief try to fetch allreduce/broadcast results from the rest of the nodes
 *  This is a collaborative function called by all nodes; only the requester node
 *  receives the cached result and backfills sendrecvbuf from it, the other
 *  nodes take part in routing and data transfer.
 *
 *  Steps: decide routing, exchange the byte size of the cached result,
 *  exchange the cached bytes, then (requester only) copy them to sendrecvbuf.
 *  The function returns as soon as any step does not report kSuccess.
 * \param sendrecvbuf destination buffer, only written when requester is true
 * \param requester whether the current node is the one requesting the cache
 * \return kSuccess, or the first non-success ReturnType reported by
 *         TryDecideRouting / TryRecoverData
 */
AllreduceRobust::ReturnType AllreduceRobust::TryLoadCache(void* sendrecvbuf, bool requester) {
  RecoverType role = requester ? kRequestData : kHaveData;
  // routing is decided based on the size of the locally held global checkpoint
  size_t size = this->global_checkpoint.length();
  int recv_link;
  std::vector<bool> req_in;
  ReturnType succ = TryDecideRouting(role, &size, &recv_link, &req_in);
  if (succ != kSuccess) return succ;
  // there is no checkpoints, which might be okay as long as resbuf has allreduce cache
  // if (size == 0) return kSuccess;

  // TODO: run allreduce min and populate restored sequence counter to kHaveData hosts
  int a = 0;
  size_t s = 0;
  void* buf = resbuf.Query(a, &s);

  // get size of allreduce buf from other nodes; stop on failure so that the
  // error is not masked by the second transfer below
  succ = TryRecoverData(role, &s, sizeof(size_t), recv_link, req_in);
  if (succ != kSuccess) return succ;
  // for requester, allocate cache and push into resbuf
  if (requester) {
    // s is a size_t, print it with a matching conversion
    utils::Printf("[%d] recover resbuf %d size %lu \n", rank, a,
                  static_cast<unsigned long>(s));
    buf = resbuf.AllocTemp(s, 1);
    // NOTE(review): the entry is registered before its bytes arrive; if the
    // transfer below fails the registered entry holds unfilled data - confirm
    resbuf.PushTemp(this->seq_counter, s, 1);
  }
  // backfill result from other hosts
  succ = TryRecoverData(role, buf, s, recv_link, req_in);
  if (succ != kSuccess) return succ;

  if (requester && sendrecvbuf != NULL) {
    // copy resbuf of seq_counter to only requester sendrecvbuf,
    // as other workers sendrecvbuf might point to other allreduce functions
    // with different code path.
    // memcpy copies all s bytes; the former memccpy(dst, src, s, 1) call
    // treated s as a stop character and copied at most one byte.
    memcpy(sendrecvbuf, buf, s);
  }

  // TODO: consider right return type
  return kSuccess;
}

/*!
* \brief try to load check point
*
Expand Down Expand Up @@ -777,6 +824,7 @@ AllreduceRobust::ReturnType AllreduceRobust::TryLoadCheckPoint(bool requester) {
global_checkpoint.resize(size);
}
if (size == 0) return kSuccess;
utils::Printf("[%d] load checkpoint size %d seq %d\n", rank, size, seq_counter);
return TryRecoverData(role, BeginPtr(global_checkpoint), size, recv_link, req_in);
}
/*!
Expand Down Expand Up @@ -848,9 +896,12 @@ AllreduceRobust::TryGetResult(void *sendrecvbuf, size_t size, int seqno, bool re
* - false means this is the latest action that has not yet been executed, need to execute the action
*/
bool AllreduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno) {

if (flag != 0) {
utils::Assert(seqno == ActionSummary::kSpecialOp, "must only set seqno for normal operations");
}

//utils::Printf("[%d] flag %d, seqno %d\n", rank, flag, seqno);
// request
ActionSummary req(flag, seqno);
while (true) {
Expand All @@ -867,6 +918,7 @@ bool AllreduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno) {
// if we requested checkpoint, we are free to go
if (req.check_point()) return true;
} else if (act.load_check()) {
// check cache
// if there is only check_ack and load_check, do load_check
if (!CheckAndRecover(TryLoadCheckPoint(req.load_check()))) continue;
// if requested load check, then misson complete
Expand Down Expand Up @@ -894,6 +946,8 @@ bool AllreduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno) {
if (act.load_check()) {
// all the nodes called load_check, this is an incomplete action
if (!act.diff_seq()) return false;
// load cache stored from other node to local TODO: consider return type
//TryLoadCache(buf, req.load_check());
// load check have higher priority, do load_check
if (!CheckAndRecover(TryLoadCheckPoint(req.load_check()))) continue;
// if requested load check, then misson complete
Expand Down
13 changes: 13 additions & 0 deletions src/allreduce_robust.h
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,19 @@ class AllreduceRobust : public AllreduceBase {
* \sa ReturnType
*/
ReturnType TryLoadCheckPoint(bool requester);

/*!
 * \brief try to load a cached allreduce/broadcast result from other nodes
 *
 * This is a collaborative function called by all nodes;
 * only the node with requester set to true actually receives the cached result,
 * the other nodes act in collaborative roles to complete this request
 * \param buf the buffer to store the result, this parameter is only used when current node is requester
 * \param requester whether current node is the requester
 * \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
 * \sa ReturnType
 */
ReturnType TryLoadCache(void *buf, bool requester);
/*!
* \brief try to get the result of operation specified by seqno
*
Expand Down
8 changes: 8 additions & 0 deletions src/engine_empty.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,14 @@ class EmptyEngine : public IEngine {
// simply print information into the tracker
utils::Printf("%s", msg.c_str());
}
// no tracker is contacted by this engine: the pair is only printed as "key-value"
virtual void TrackerSetConfig(const std::string &key, const std::string &value) {
  const char *key_text = key.c_str();
  const char *value_text = value.c_str();
  utils::Printf("%s-%s", key_text, value_text);
}
// no tracker is contacted by this engine: only the key is printed,
// the output string behind value is left untouched
virtual void TrackerGetConfig(const std::string& key, std::string* value) {
  const char *key_text = key.c_str();
  utils::Printf("%s", key_text);
}

private:
int version_number;
Expand Down
9 changes: 9 additions & 0 deletions src/engine_mpi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,15 @@ class MPIEngine : public IEngine {
utils::Printf("%s", msg.c_str());
}
}
// TODO(chen qin): figure out how to support MPI
// for now the pair is only printed as "key-value", nothing is stored
virtual void TrackerSetConfig(const std::string &key, const std::string &value) {
  const char *key_text = key.c_str();
  const char *value_text = value.c_str();
  utils::Printf("%s-%s", key_text, value_text);
}
// only the key is printed; the output string behind value is left untouched
virtual void TrackerGetConfig(const std::string& key, std::string* value) {
  const char *key_text = key.c_str();
  utils::Printf("%s", key_text);
}

private:
int version_number;
Expand Down
10 changes: 9 additions & 1 deletion test/Makefile
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
RABIT_BUILD_DMLC = 0

ifeq ($(RABIT_BUILD_DMLC),1)
DMLC=../dmlc-core
else
DMLC=../../dmlc-core
endif

MPICXX=../mpich/bin/mpicxx
export LDFLAGS= -L../lib -pthread -lm
export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -fPIC -I../include -I ../dmlc-core/include -std=c++11
export CFLAGS = -Wall -O3 -msse3 -g -Wno-unknown-pragmas -fPIC -I../include -I $(DMLC)/include -std=c++11

OS := $(shell uname)

Expand Down
Loading