diff --git a/tracker/dmlc_tracker/local.py b/tracker/dmlc_tracker/local.py index 0fa00faeec..6e4af12575 100644 --- a/tracker/dmlc_tracker/local.py +++ b/tracker/dmlc_tracker/local.py @@ -9,7 +9,7 @@ from threading import Thread from . import tracker -def exec_cmd(cmd, role, taskid, pass_env): +def exec_cmd(cmd, num_attempt, role, taskid, pass_env): """Execute the command line command.""" if cmd[0].find('/') == -1 and os.path.exists(cmd[0]) and os.name != 'nt': cmd[0] = './' + cmd[0] @@ -21,12 +21,11 @@ def exec_cmd(cmd, role, taskid, pass_env): env['DMLC_TASK_ID'] = str(taskid) env['DMLC_ROLE'] = role env['DMLC_JOB_CLUSTER'] = 'local' - num_retry = env.get('DMLC_NUM_ATTEMPT', 0) - #overwrite default num of retry with commandline value - for param in cmd: - if param.startswith('DMLC_NUM_ATTEMPT'): - num_retry = int(param.split('=')[1]) + # backward compatibility + num_retry = env.get('DMLC_NUM_ATTEMPT', num_attempt) + num_trial = 0 + logging.debug('num of retry %d',num_retry) while True: @@ -38,17 +37,11 @@ def exec_cmd(cmd, role, taskid, pass_env): logging.debug('Thread %d exit with 0', taskid) return else: + num_trial += 1 num_retry -= 1 - newcmd = [] + if num_retry >= 0: - # failure trail increase by 1 and restart failed worker - for arg in cmd: - if arg.startswith('rabit_num_trial'): - val = arg.split('=')[1] - arg = arg.replace(val, str(int(val)+1)) - newcmd.append(arg) - cmdline = ' '.join(newcmd) - cmd = newcmd + cmdline = ' '.join(cmd + ['DMLC_NUM_ATTEMPT=' + str(num_trial)]) continue if os.name == 'nt': sys.exit(-1) @@ -75,7 +68,7 @@ def mthread_submit(nworker, nserver, envs): role = 'worker' else: role = 'server' - procs[i] = Thread(target=exec_cmd, args=(args.command, role, i, envs)) + procs[i] = Thread(target=exec_cmd, args=(args.command, args.local_num_attempt, role, i, envs)) procs[i].setDaemon(True) procs[i].start() diff --git a/tracker/dmlc_tracker/opts.py b/tracker/dmlc_tracker/opts.py index 443c92a179..d642b004fb 100644 --- a/tracker/dmlc_tracker/opts.py +++ b/tracker/dmlc_tracker/opts.py @@ -162,6 +162,8 @@ def get_opts(args=None): parser.add_argument('--kube-server-template', default=None, type=str, help=('Manifest template for servers. Used only in Kubernetes mode.' + 'Can be used to override defaults.')) + parser.add_argument('--local-num-attempt', default=0, type=int, + help=('Number of attempt local tracker can restart slave.')) (args, unknown) = parser.parse_known_args(args) args.command += unknown