Skip to content

Commit

Permalink
Merge pull request #512 from chenqin/master
Browse files: browse the repository at this point in the history
[rabit harden] remove rabit_num_trial from dmlc-core
  • Loading branch information
szha authored Mar 14, 2019
2 parents 9acdddd + 7e79220 commit e7d2014
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 16 deletions.
25 changes: 9 additions & 16 deletions tracker/dmlc_tracker/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from threading import Thread
from . import tracker

def exec_cmd(cmd, role, taskid, pass_env):
def exec_cmd(cmd, num_attempt, role, taskid, pass_env):
"""Execute the command line command."""
if cmd[0].find('/') == -1 and os.path.exists(cmd[0]) and os.name != 'nt':
cmd[0] = './' + cmd[0]
Expand All @@ -21,12 +21,11 @@ def exec_cmd(cmd, role, taskid, pass_env):
env['DMLC_TASK_ID'] = str(taskid)
env['DMLC_ROLE'] = role
env['DMLC_JOB_CLUSTER'] = 'local'
num_retry = env.get('DMLC_NUM_ATTEMPT', 0)

#overwrite default num of retry with commandline value
for param in cmd:
if param.startswith('DMLC_NUM_ATTEMPT'):
num_retry = int(param.split('=')[1])
# backward compatibility
num_retry = env.get('DMLC_NUM_ATTEMPT', num_attempt)
num_trial = 0

logging.debug('num of retry %d',num_retry)

while True:
Expand All @@ -38,17 +37,11 @@ def exec_cmd(cmd, role, taskid, pass_env):
logging.debug('Thread %d exit with 0', taskid)
return
else:
num_trial += 1
num_retry -= 1
newcmd = []

if num_retry >= 0:
# failure trail increase by 1 and restart failed worker
for arg in cmd:
if arg.startswith('rabit_num_trial'):
val = arg.split('=')[1]
arg = arg.replace(val, str(int(val)+1))
newcmd.append(arg)
cmdline = ' '.join(newcmd)
cmd = newcmd
cmdline = ' '.join(cmd + ['DMLC_NUM_ATTEMPT=' + str(num_trial)])
continue
if os.name == 'nt':
sys.exit(-1)
Expand All @@ -75,7 +68,7 @@ def mthread_submit(nworker, nserver, envs):
role = 'worker'
else:
role = 'server'
procs[i] = Thread(target=exec_cmd, args=(args.command, role, i, envs))
procs[i] = Thread(target=exec_cmd, args=(args.command, args.local_num_attempt, role, i, envs))
procs[i].setDaemon(True)
procs[i].start()

Expand Down
2 changes: 2 additions & 0 deletions tracker/dmlc_tracker/opts.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,8 @@ def get_opts(args=None):
parser.add_argument('--kube-server-template', default=None, type=str,
help=('Manifest template for servers. Used only in Kubernetes mode.' +
'Can be used to override defaults.'))
parser.add_argument('--local-num-attempt', default=0, type=int,
help=('Number of attempt local tracker can restart slave.'))
(args, unknown) = parser.parse_known_args(args)
args.command += unknown

Expand Down

0 comments on commit e7d2014

Please sign in to comment.