
Commit

[auto config] Add progress display
xysheng-baidu committed Dec 8, 2023
1 parent d282b29 commit 7e357b4
Showing 2 changed files with 119 additions and 1 deletion.
46 changes: 45 additions & 1 deletion python/paddle/distributed/auto_tuner/prune.py
@@ -12,10 +12,37 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import logging

from paddle.distributed.launch.main import ctx

logger = logging.getLogger('auto_tuner')

_PRUNE_FUNC = []
_PRUNE_HISTORY_FUNC = []


def log_pruned_info(cur_cfg, pruned_reason):
pruned_strategy = "DP{}_MP{}_PP{}_VPP_{}_Sharding{}_Stage{}_MBS{}_Recompute_{}_Granularity_{}".format(
cur_cfg["dp_degree"],
cur_cfg["mp_degree"],
cur_cfg["pp_degree"],
cur_cfg["vpp_degree"],
cur_cfg["sharding_degree"],
cur_cfg["sharding_stage"],
cur_cfg["micro_batch_size"],
cur_cfg["use_recompute"],
cur_cfg["recompute_granularity"],
)
ctx.logger.info(
f"Strategy {pruned_strategy} has been pruned because {pruned_reason}"
)
logger.info(
f"Strategy {pruned_strategy} has been pruned because {pruned_reason}"
)


def same_cfgs_beside(attr, cur_cfg, history_cfgs=[]):
"""
Compare the current configuration with the history configuration,
@@ -34,6 +61,7 @@ def same_cfgs_beside(attr, cur_cfg, history_cfgs=[]):
results.append(cfg)
else:
same = True

return results


@@ -215,6 +243,8 @@ def prune_by_vpp_history(tuner_cfg, cur_cfg, history_cfgs=[]):
cfg["vpp_degree"] > vpp_degree
and cfg.get("max_mem_usage") == "OOM"
):
pruned_reason = f"vpp_degree {vpp_degree} may cause OOM because vpp_degree {cfg['vpp_degree']} already caused OOM."
log_pruned_info(cur_cfg, pruned_reason)
return True
return False

@@ -282,13 +312,17 @@ def prune_by_mbs_history(tuner_cfg, cur_cfg, history_cfgs=[]):
cfg["micro_batch_size"] > micro_batch_size
and cfg.get("time", -1) > 0
):
pruned_reason = f"micro_batch_size {micro_batch_size} may be slower because micro_batch_size {cfg['micro_batch_size']} already ran successfully."
log_pruned_info(cur_cfg, pruned_reason)
return True

# memory prune
if (
cfg["micro_batch_size"] < micro_batch_size
and cfg.get("max_mem_usage") == "OOM"
):
pruned_reason = f"micro_batch_size {micro_batch_size} may cause OOM because micro_batch_size {cfg['micro_batch_size']} already caused OOM."
log_pruned_info(cur_cfg, pruned_reason)
return True
return False

@@ -343,20 +377,25 @@ def prune_by_sharding_history(tuner_cfg, cur_cfg, history_cfgs=[]):
sharding_stage = cur_cfg.get("sharding_stage", None)
if sharding_stage is None:
return False

cfgs = same_cfgs_beside("sharding_stage", cur_cfg, history_cfgs)
if cfgs:
for cfg in cfgs:
if (
cfg["sharding_stage"] < sharding_stage
and cfg.get("time", -1) > 0
):
pruned_reason = f"sharding_stage {sharding_stage} may be slower because sharding_stage {cfg['sharding_stage']} already ran successfully."
log_pruned_info(cur_cfg, pruned_reason)
return True

# memory prune
if (
cfg["sharding_stage"] > sharding_stage
and cfg.get("max_mem_usage") == "OOM"
):
pruned_reason = f"sharding_stage {sharding_stage} may cause OOM because sharding_stage {cfg['sharding_stage']} already caused OOM."
log_pruned_info(cur_cfg, pruned_reason)
return True

if sharding_degree == 1:
@@ -412,20 +451,25 @@ def prune_by_recompute_history(tuner_cfg, cur_cfg, history_cfgs=[]):
and use_recompute
and cfg.get("time", -1) > 0
):
pruned_reason = f"use_recompute {use_recompute} may be slower because use_recompute {cfg['use_recompute']} already ran successfully."
log_pruned_info(cur_cfg, pruned_reason)
return True

if (
cfg["use_recompute"]
and not use_recompute
and cfg.get("max_mem_usage") == "OOM"
):
pruned_reason = f"use_recompute {use_recompute} may cause OOM because use_recompute {cfg['use_recompute']} already caused OOM."
log_pruned_info(cur_cfg, pruned_reason)
return True

if not use_recompute:
cfgs = same_cfgs_beside("recompute_granularity", cur_cfg, history_cfgs)
if cfgs:
pruned_reason = f"recompute_granularity {cur_cfg['recompute_granularity']} is invalid because use_recompute is {use_recompute}."
log_pruned_info(cur_cfg, pruned_reason)
return True

return False


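Note on how these hooks fit together: each prune_by_*_history function receives the candidate configuration plus the configurations already evaluated, and with this commit calls log_pruned_info with a human-readable reason whenever the candidate can be skipped. As a minimal standalone sketch (not part of this commit), the driver loop below shows one way the registered hooks in _PRUNE_HISTORY_FUNC could be applied before launching a candidate; the name filter_candidates and the iteration order are illustrative assumptions only.

def filter_candidates(candidate_cfgs, tuner_cfg):
    # Hypothetical driver: yield only configs that no history-based hook prunes.
    # _PRUNE_HISTORY_FUNC is the module-level registry shown above.
    history_cfgs = []
    for cur_cfg in candidate_cfgs:
        pruned = False
        for prune_func in _PRUNE_HISTORY_FUNC:
            # Each hook returns True when the candidate can be skipped and is
            # expected to report its reason through log_pruned_info.
            if prune_func(tuner_cfg, cur_cfg, history_cfgs):
                pruned = True
                break
        if not pruned:
            yield cur_cfg
        history_cfgs.append(cur_cfg)
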
74 changes: 74 additions & 0 deletions python/paddle/distributed/launch/main.py
@@ -14,6 +14,8 @@

from .context import Context

ctx = None


def launch():
"""
@@ -284,6 +286,7 @@ def launch():
"""

# initialize the context to run
global ctx
ctx = Context()

if ctx.is_legacy_mode():
@@ -435,6 +438,7 @@ def launch():
recorder = HistoryRecorder()

job_id = 0
error_task_nums = 0
ctx.args.max_restart = -1
raw_ctx = copy.deepcopy(ctx)

@@ -703,6 +707,42 @@ def launch():
add_overlap_performance(
cur_cfg, tuner_cfg, recorder.history
)
has_error = cur_cfg["has_error"]
if has_error:
error_task_nums += 1
error_info = cur_cfg["error_info"]
task_nums = len(auto_tuner.algo.all_tasks)
cur_task_id = auto_tuner.algo.idx
ctx.logger.info(
"Auto Tuner Schedule: [{}/{}], Pruned nums {}, Error nums {}, Error info {}, Remaining time {} min".format(
cur_task_id,
task_nums,
cur_task_id - job_id,
error_task_nums,
error_info,
round(
(task_nums - cur_task_id)
* max_time_per_task
/ 60,
2,
),
)
)
logger.info(
"Auto Tuner Schedule: [{}/{}], Pruned nums {}, Error nums {}, Error info {}, Remaining time {} min".format(
cur_task_id,
task_nums,
cur_task_id - job_id,
error_task_nums,
error_info,
round(
(task_nums - cur_task_id)
* max_time_per_task
/ 60,
2,
),
)
)
recorder.store_history(history_file_path)
# generate a new config
new_cfg = auto_tuner.search_once()
@@ -862,6 +902,40 @@ def launch():
f"bw_{bw}_{tuner_cfg['metric_cfg']['name']}"
] = multi_dp_performace

cur_cfg["has_error"] = has_error
if has_error:
error_task_nums += 1
error_info = None
cur_cfg["error_info"] = error_info
task_nums = len(auto_tuner.algo.all_tasks)
cur_task_id = auto_tuner.algo.idx
ctx.logger.info(
"Auto Tuner Schedule: [{}/{}], Pruned nums {}, Error nums {}, Error info {}, Remaining time {} min".format(
cur_task_id,
task_nums,
cur_task_id - job_id,
error_task_nums,
error_info,
round(
(task_nums - cur_task_id) * max_time_per_task / 60,
2,
),
)
)
logger.info(
"Auto Tuner Schedule: [{}/{}], Pruned nums {}, Error nums {}, Error info {}, Remaining time {} min".format(
cur_task_id,
task_nums,
cur_task_id - job_id,
error_task_nums,
error_info,
round(
(task_nums - cur_task_id) * max_time_per_task / 60,
2,
),
)
)

# sync for single dp
if sorted_ips:
master_ip = sorted_ips[0]
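For reference, the remaining-time figure in the new progress line is simply the number of not-yet-evaluated candidates multiplied by the per-task time budget, converted to minutes. The standalone sketch below reproduces that arithmetic with hypothetical values; the task counts, job_id, and the 300-second max_time_per_task are assumptions for illustration, and max_time_per_task is treated as seconds, consistent with the division by 60 above.

# Standalone sketch of the progress-line arithmetic; all values are hypothetical.
task_nums = 120          # total candidate configs generated by the search algorithm
cur_task_id = 45         # index of the candidate currently being processed
job_id = 40              # candidates actually launched, so 45 - 40 = 5 were pruned
error_task_nums = 2      # launched candidates that ended with an error
error_info = None        # last recorded error detail, if any
max_time_per_task = 300  # per-task time budget, assumed to be in seconds

remaining_min = round((task_nums - cur_task_id) * max_time_per_task / 60, 2)
print(
    "Auto Tuner Schedule: [{}/{}], Pruned nums {}, Error nums {}, Error info {}, Remaining time {} min".format(
        cur_task_id,
        task_nums,
        cur_task_id - job_id,
        error_task_nums,
        error_info,
        remaining_min,
    )
)
# Prints: Auto Tuner Schedule: [45/120], Pruned nums 5, Error nums 2, Error info None, Remaining time 375.0 min
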
