From 1314aa88f47ee00df101acf1cd7d387dbe87a72d Mon Sep 17 00:00:00 2001 From: zhaoyingli <86812880+zhaoyinglia@users.noreply.github.com> Date: Wed, 14 Dec 2022 19:38:34 +0800 Subject: [PATCH] [AutoParallel] recompute tuning (#48608) * [AutoParallel] recompute tuning * fix conflict * update comment * bug fix * update rc algo * tiny fix * fix clear process_group * remove comment * update segment print * fix import OpRole * adapt amp pass and grad_clip pass for opt_tuner * update tuning config * fix import * annotate recompute info on ops and upgrade recompute pass * add op_namescope for seed op * record reserved vars * fix recompute var's dist_attr * fix strategy unittest * adapt for fp16 * update unittest * revert copy opt * update unittest * rename set_recompute_segments * fix unittest --- .../distributed/auto_parallel/constants.py | 6 +- .../distributed/auto_parallel/engine.py | 5 +- .../distributed/auto_parallel/strategy.py | 4 + .../auto_parallel/tuner/algorithms.py | 95 +++++- .../distributed/auto_parallel/tuner/config.py | 50 ++-- .../auto_parallel/tuner/optimization_tuner.py | 47 ++- .../auto_parallel/tuner/profiler.py | 22 +- .../paddle/distributed/auto_parallel/utils.py | 96 +++++-- .../distributed/passes/auto_parallel_amp.py | 6 + .../distributed/passes/auto_parallel_fp16.py | 9 +- .../passes/auto_parallel_grad_clip.py | 3 +- .../passes/auto_parallel_gradient_merge.py | 3 +- .../passes/auto_parallel_recompute.py | 270 ++++++++---------- .../unittests/auto_parallel/CMakeLists.txt | 2 + .../unittests/auto_parallel/get_gpt_model.py | 7 +- .../auto_parallel/optimization_tuner_api.py | 2 +- .../unittests/auto_parallel/test_strategy.py | 6 +- .../auto_parallel/test_tuning_recompute.py | 111 +++++++ 18 files changed, 495 insertions(+), 249 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/auto_parallel/test_tuning_recompute.py diff --git a/python/paddle/distributed/auto_parallel/constants.py b/python/paddle/distributed/auto_parallel/constants.py index 857245b9be425..ce72304dc75cd 100644 --- a/python/paddle/distributed/auto_parallel/constants.py +++ b/python/paddle/distributed/auto_parallel/constants.py @@ -54,7 +54,7 @@ def set_field_default_config(category, field, default_value): ######################################### RECOMPUTE = "recompute" set_field_default_config(RECOMPUTE, "enable", False) -set_field_default_config(RECOMPUTE, "checkpoints", None) +set_field_default_config(RECOMPUTE, "checkpoints", []) set_field_default_config(RECOMPUTE, "no_recompute_segments", []) set_field_default_config(RECOMPUTE, "enable_tuning", False) @@ -113,12 +113,10 @@ def set_field_default_config(category, field, default_value): # ######################################### TUNING = "tuning" set_field_default_config(TUNING, "enable", False) -set_field_default_config(TUNING, "batch_size", 1) -set_field_default_config(TUNING, "dataset", None) set_field_default_config(TUNING, "profile_start_step", 1) set_field_default_config(TUNING, "profile_end_step", 1) set_field_default_config(TUNING, "run_after_tuning", True) -set_field_default_config(TUNING, "verbose", True) +set_field_default_config(TUNING, "debug", False) ######################################### # dataset configuration diff --git a/python/paddle/distributed/auto_parallel/engine.py b/python/paddle/distributed/auto_parallel/engine.py index 092212a87168b..dc7470283aef8 100644 --- a/python/paddle/distributed/auto_parallel/engine.py +++ b/python/paddle/distributed/auto_parallel/engine.py @@ -609,7 +609,9 @@ def _build(self, mode): if mode != "train": serial_main_prog = serial_main_prog.clone(for_test=True) - auto_utils.set_recompute_ckpts(self._model, self._strategy) + auto_utils.set_recompute_segments( + self._model, self._losses, self._strategy, serial_main_prog + ) self._dist_contexts[mode] = DistributedContext( serial_main_prog, serial_startup_prog, @@ -649,7 +651,6 @@ def _optimization_tuning(self, mode, dataset, batch_size): from .tuner.optimization_tuner import OptimizationTuner self._optimization_tuner = OptimizationTuner( - self._tuning.to_dict(), self._dist_contexts[mode], dataset, self._inputs_spec, diff --git a/python/paddle/distributed/auto_parallel/strategy.py b/python/paddle/distributed/auto_parallel/strategy.py index 4d626bb6ae495..7e6b98665a8d0 100644 --- a/python/paddle/distributed/auto_parallel/strategy.py +++ b/python/paddle/distributed/auto_parallel/strategy.py @@ -73,6 +73,10 @@ def __deepcopy__(self, memo): setattr(result, k, copy.deepcopy(v, memo)) return result + def get(self, k, d=None): + result_dict = self.to_dict() + return result_dict.get(k, d) + class RecomputeConfig(BaseConfig): def __init__(self, config_dict=None): diff --git a/python/paddle/distributed/auto_parallel/tuner/algorithms.py b/python/paddle/distributed/auto_parallel/tuner/algorithms.py index 8ce570d03c288..74e8f3e9ee3f1 100644 --- a/python/paddle/distributed/auto_parallel/tuner/algorithms.py +++ b/python/paddle/distributed/auto_parallel/tuner/algorithms.py @@ -16,7 +16,7 @@ import logging from abc import ABC, abstractmethod -from ..utils import get_logger +from ..utils import get_logger, is_recompute_op from .trial import OptimizationTunerTrial as Trial from .trial import TrialStatus @@ -54,7 +54,7 @@ def changed_configs(self): def collect_model_info(self, main_prog, startup_prog): """ Collect the model static info (from programs) that could be used to - pruning candidate trials and saving tuning time.For instance, + pruning candidate trials and saving tuning time. For instance, model info like number of model parameters and activation memory could be used to prune candidated trial and decide the next trial. """ @@ -116,7 +116,7 @@ def _init_spaces(self): self._max_stage = 3 self._trial_idx = 0 - stage_range = self._config.sharding.to_dict().get("tuning_range", None) + stage_range = self._config.sharding.get("tuning_range", None) if stage_range: assert set(stage_range).issubset( set([0, 1, 2, 3]) @@ -157,3 +157,92 @@ def update(self, results): ) else: self._trial_idx += 1 + + +@register_algor("recompute") +class ReccomputeCheckpointAlgorithm(AlgorithmBase): + def __init__(self, config): + super().__init__(config) + self._changed_configs = ["recompute"] + + def collect_model_info(self, main_prog, startup_prog): + segments = [] + for op in main_prog.global_block().ops: + if not is_recompute_op(op): + continue + + seg_name = op.attr('op_namescope') + if seg_name not in segments: + segments.append(seg_name) + + self._total_num_trial = len(segments) + self._tuning_segments = list(range(len(segments))) + self._trail_left = 0 + self._trail_right = len(segments) - 1 + self._trial_idx = int(0 + (len(segments) - 1) / 2) + + def _init_spaces(self): + self._recompute_mode = "all" + + def next_trial(self): + if self._trial_idx < self._total_num_trial: + if self._recompute_mode == "all": + self._recompute_flag = False + new_strategy = copy.deepcopy(self._config.dist_strategy) + name = "trial-recompute-all-segments" + return Trial(new_strategy, name, self.changed_configs) + elif self._recompute_mode == "none": + self._recompute_flag = False + new_strategy = copy.deepcopy(self._config.dist_strategy) + recompute = new_strategy.recompute + recompute.enable = False + name = "trial-recompute-none-segments" + return Trial(new_strategy, name, self.changed_configs) + elif self._recompute_mode == "part": + new_no_recompute = self._tuning_segments[: self._trial_idx] + new_strategy = copy.deepcopy(self._config.dist_strategy) + recompute = new_strategy.recompute + recompute.no_recompute_segments.extend(new_no_recompute) + name = "trial-recompute-part-segments-idx{}".format( + self._trial_idx + ) + return Trial(new_strategy, name, self.changed_configs) + else: + return Trial(None, None, None, status=TrialStatus.STOPPED) + + def update(self, results): + + et = results.get("ErrorType", None) + if self._recompute_mode == "all": + if et and et == "ResourceExhaustedError": + self._trial_idx = self._total_num_trial + self._logger.info( + "Recompute all candidate segments is failed with OOM, please reduce model size or batch size." + ) + else: + self._recompute_mode = "none" + elif self._recompute_mode == "none": + if et and et == "ResourceExhaustedError": + self._recompute_mode = "part" + else: + self._trial_idx = self._total_num_trial + self._logger.info( + "Recompute is unnecessary for this model size, which will reduce the Throughtput." + ) + else: + if self._trail_left >= self._trail_right: + self._trial_idx = self._total_num_trial + elif et and et == "ResourceExhaustedError": + self._trail_left = self._trail_left + self._trail_right = self._trial_idx - 1 + self._trial_idx = int( + self._trail_left + + (self._trail_right - self._trail_left) / 2 + ) + else: + self._trail_left = self._trial_idx + 1 + self._trail_right = self._trail_right + self._trial_idx = int( + self._trail_left + + (self._trail_right - self._trail_left) / 2 + ) diff --git a/python/paddle/distributed/auto_parallel/tuner/config.py b/python/paddle/distributed/auto_parallel/tuner/config.py index f47ec1ae2d041..78f94b87b360b 100644 --- a/python/paddle/distributed/auto_parallel/tuner/config.py +++ b/python/paddle/distributed/auto_parallel/tuner/config.py @@ -32,14 +32,11 @@ class TuningConfig: tuning config: configuration for the tuning process: mode (profile or cost model), log dir, extra tuning config for optimization like search range for specific """ - def __init__(self, user_config, strategy): + def __init__(self, strategy): if not isinstance(strategy, Strategy): raise TypeError("'strategy' must be object of class `Strategy`.") - if not user_config: - user_config = {} - self._tuning_passes_name = set() self._dist_strategy = copy.deepcopy(strategy) self._mode = None @@ -48,9 +45,9 @@ def __init__(self, user_config, strategy): self._project_dir = None self._max_num_trial = None self._early_stop = None - self._verbose = None + self._debug = None - self._initialize(user_config) + self._initialize() @property def mode(self): @@ -81,29 +78,25 @@ def early_stop(self): return self._early_stop @property - def verbose(self): - return self._verbose + def debug(self): + return self._debug @property def dist_strategy(self): return self._dist_strategy # initialize config with user define value or default value - def _initialize(self, user_config): - - self._mode = user_config.get("mode", "PROFILE") - - self._profile_start_step = user_config.get("profile_start_step", 10) - - self._profile_end_step = user_config.get("profile_end_step", 30) - - self._max_num_trial = user_config.get("max_num_trial", 50) - - self._early_stop = user_config.get("early_stop", None) + def _initialize(self): + tuning_strategy = self._dist_strategy.tuning - self._verbose = user_config.get("verbose", False) + self._mode = tuning_strategy.get("mode", "PROFILE") + self._profile_start_step = tuning_strategy.get("profile_start_step", 10) + self._profile_end_step = tuning_strategy.get("profile_end_step", 30) + self._max_num_trial = tuning_strategy.get("max_num_trial", 50) + self._early_stop = tuning_strategy.get("early_stop", None) + self._debug = tuning_strategy.get("debug", False) - project_dir = user_config.get("project_dir", None) + project_dir = tuning_strategy.get("project_dir", None) if not project_dir: project_dir = os.path.join(os.getcwd(), "OptimizationTuning") self._project_dir = project_dir @@ -116,15 +109,14 @@ def _initialize(self, user_config): # TODO distinguish different args of each passes self._tuning_passes_name.add(p) - config_name = p - p_dict = getattr(self._dist_strategy, config_name) - self.__dict__[config_name] = p_dict + p_strategy = getattr(self._dist_strategy, p) + self.__dict__[p] = p_strategy - # TODO verify the user defined configs - user_config_for_pass = user_config.get(p, None) - if user_config_for_pass: - for k, v in user_config_for_pass.items(): - self.__dict__[config_name][k] = v + # # TODO verify the user defined configs + # tuning_config_for_pass = tuning_strategy.get(p, None) + # if tuning_config_for_pass: + # for k, v in tuning_config_for_pass.items(): + # self.__dict__[p][k] = v # (NOTE)tuning config ONLY wraps dist strategy for pass config which is to be tuned def __getattr__(self, item): diff --git a/python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py b/python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py index 8a2867a315d3e..c3de081c752ba 100644 --- a/python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py +++ b/python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py @@ -33,6 +33,7 @@ from paddle.distributed.auto_parallel.process_group import ( clear_all_process_groups, get_all_process_groups, + new_process_group, ) from paddle.distributed.auto_parallel.reshard import Resharder from paddle.distributed.auto_parallel.utils import ( @@ -40,7 +41,7 @@ set_grad_var_shape, ) from paddle.distributed.passes import PassContext, new_pass -from paddle.fluid import program_guard +from paddle.fluid import program_guard, unique_name from paddle.fluid.backward import append_backward from ..utils import get_logger @@ -109,7 +110,12 @@ def parse_results(results): # all env need to be start a new pass are member of dist context def _copy_context(ref_dist_context): + # clear all process groups and recover the world process group clear_all_process_groups() + ranks = [] + for process_mesh in ref_dist_context._process_meshes: + ranks.extend(process_mesh.processes) + new_process_group(list(set(ranks))) new_dist_context = DistributedContext() new_dist_context._serial_main_program = ( @@ -195,7 +201,6 @@ class OptimizationTuner: def __init__( self, - user_configs, dist_context, dataset, inputs_spec, @@ -204,7 +209,7 @@ def __init__( rank, ): - self._config = TuningConfig(user_configs, dist_context._strategy) + self._config = TuningConfig(dist_context.strategy) # should not modify dist context from calling function self._baseline_dist_context = _copy_context(dist_context) self._baseline_completer = Completer(self._baseline_dist_context) @@ -264,7 +269,7 @@ def _build_programs_without_optimization(self): ) self._baseline_dist_context._params_grads = params_grads - if self._config.verbose: + if self._config.debug: baseline_dir = os.path.join(self.project_dir, "baseline") if not os.path.exists(baseline_dir): pathlib.Path(baseline_dir).mkdir(parents=True, exist_ok=True) @@ -299,7 +304,6 @@ def _apply_optimization(self, trial): config = copy.deepcopy(new_strategy.amp.to_dict()) config["dist_context"] = dist_context config["params_grads"] = dist_context._params_grads - # TODO AMP Pass should not use loss var config["loss"] = dist_context.serial_loss config["input_data"] = ( @@ -312,13 +316,13 @@ def _apply_optimization(self, trial): auto_parallel_fp16_pass.apply( [main_program], [startup_program], pass_context ) - dist_context.serial_loss = auto_parallel_fp16_pass.get_loss() + dist_context._serial_loss = auto_parallel_fp16_pass.get_loss() else: auto_parallel_amp_pass = new_pass("auto_parallel_amp", config) auto_parallel_amp_pass.apply( [main_program], [startup_program], pass_context ) - dist_context.serial_loss = auto_parallel_amp_pass.get_loss() + dist_context._serial_loss = auto_parallel_amp_pass.get_loss() if new_strategy.recompute.enable: config = copy.deepcopy(new_strategy.recompute.to_dict()) @@ -345,9 +349,10 @@ def _apply_optimization(self, trial): # Generate optimizer # FIXME should be remove from apply pass after pass support optimizers with program_guard(dist_main_prog, dist_startup_prog): - optimizer_ops = dist_context.serial_optimizer.apply_gradients( - dist_params_grads - ) + with unique_name.guard("opt_"): + optimizer_ops = dist_context.serial_optimizer.apply_gradients( + dist_params_grads + ) completer.complete_update_annotation(dist_main_prog) # Do reshard process @@ -361,6 +366,13 @@ def _apply_optimization(self, trial): ) resharder.reshard() + config = {} + config["dist_context"] = dist_context + config["global_rank"] = self.rank + config["use_sharding"] = new_strategy.sharding.enable + dp_pass = new_pass("auto_parallel_data_parallel_optimization", config) + dp_pass.apply([dist_main_prog], [dist_startup_prog], pass_context) + if new_strategy.sharding.enable: config = copy.deepcopy(new_strategy.sharding.to_dict()) config["dist_context"] = dist_context @@ -372,6 +384,17 @@ def _apply_optimization(self, trial): auto_parallel_sharding_pass.apply( [dist_main_prog], [dist_startup_prog], pass_context ) + dist_params_grads = pass_context.get_attr("params_grads") + + # gradient clip + config = copy.deepcopy(new_strategy.sharding.to_dict()) + config["dist_context"] = dist_context + config["params_grads"] = dist_params_grads + config["rank_id"] = self.rank + auto_parallel_clip_pass = new_pass("auto_parallel_grad_clip", config) + auto_parallel_clip_pass.apply( + [dist_main_prog], [dist_startup_prog], pass_context + ) if new_strategy.gradient_merge.enable: config = copy.deepcopy(new_strategy.gradient_merge.to_dict()) @@ -488,7 +511,7 @@ def _profile_trial(self, trial): with open(ctx_path, 'wb') as f: pickle.dump(profile_ctx, f, protocol=4) - if self._config.verbose: + if self._config.debug: debug_program(trial.main_program, trial_dir, "main_program") debug_program(trial.startup_program, trial_dir, "startup_program") @@ -581,7 +604,7 @@ def clear(self): Clear the temporary file generated in tuning procedure. """ # TODO clear up zombie process created by tuning - if not self._config.verbose: + if not self._config.debug: for trial in self._finished_trials: trial_dir = self._get_trial_dir(trial) shutil.rmtree(trial_dir, ignore_errors=True) diff --git a/python/paddle/distributed/auto_parallel/tuner/profiler.py b/python/paddle/distributed/auto_parallel/tuner/profiler.py index 4a4dfea763157..cdd4a0045c8c9 100644 --- a/python/paddle/distributed/auto_parallel/tuner/profiler.py +++ b/python/paddle/distributed/auto_parallel/tuner/profiler.py @@ -89,7 +89,7 @@ def init_process_groups(group_map, rank): # TODO should instantiate global group first all_process_groups = get_all_process_groups() for process_group in all_process_groups: - if process_group.id == 0 or rank not in process_group.ranks: + if rank not in process_group.ranks: continue print(process_group) process_group.instantiate() @@ -173,10 +173,11 @@ def init_comm(profile_ctx): genv = _get_global_env() genv = dist_env print( - "current process rank: {}, device_id: {}, ip: {}.", - genv.rank, - genv.device_id, - genv.current_endpoint, + "current process rank: {}, device_id: {}, ip: {}.".format( + genv.rank, + genv.device_id, + genv.current_endpoint, + ) ) # init nccl comm @@ -231,13 +232,12 @@ def profiler(args): exe = get_executor() - exe.run(startup_program) - - # profile main - duration = 0 - eval_step = 0 - data_loader._inner_dataloader.start() try: + exe.run(startup_program) + # profile main + duration = 0 + eval_step = 0 + data_loader._inner_dataloader.start() while eval_step < args.profile_end_step: start_time = time.time() diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py index 0883417fc9e82..4d474569fb3eb 100644 --- a/python/paddle/distributed/auto_parallel/utils.py +++ b/python/paddle/distributed/auto_parallel/utils.py @@ -22,18 +22,17 @@ import numpy as np import paddle -import paddle.fluid.core as core -from paddle.distributed.auto_parallel.dist_attribute import ( +from paddle.fluid.framework import Variable +from paddle.fluid.io import is_belong_to_optimizer, is_parameter +from paddle.framework import core + +from .dist_attribute import ( OperatorDistributedAttribute, TensorDistributedAttribute, ) -from paddle.distributed.auto_parallel.process_group import ( - get_all_process_groups, -) -from paddle.distributed.fleet.meta_optimizers.common import OpRole -from paddle.fluid.framework import Variable -from paddle.fluid.io import is_belong_to_optimizer, is_parameter +from .process_group import get_all_process_groups +OpRole = core.op_proto_and_checker_maker.OpRole OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName() __no_shape_var_type__ = [ @@ -1921,10 +1920,16 @@ def initialize_pg_in_full_mode(all_process_groups, cur_rank): server_socket.close() -def set_recompute_ckpts(model, strategy): - from .interface import _g_recompute_idx +def is_recompute_op(op): + return op.has_attr('op_namescope') and "/auto_parallel/rc" in op.attr( + 'op_namescope' + ) - if _g_recompute_idx > -1: + +def set_recompute_segments(model, losses, strategy, program): + from ..passes.auto_parallel_recompute import RecomputeState + + if not losses: return recompute = strategy.recompute @@ -1934,24 +1939,65 @@ def set_recompute_ckpts(model, strategy): # NOTE: hack to enable recompute in engine api for GPT-3 # TODO support more PaddleNLP/CV models here # extract ckpts by specific model + ckpts = [] if isinstance(model, paddle.nn.Layer): - if hasattr(model, "gpt") and model.__class__.__name__ in [ - 'GPTForPretraining', - 'GPTForPretrainingAuto', - ]: - exact_ckpts = model.gpt.checkpoints + if ( + hasattr(model, "gpt") + and model.__class__.__name__ + in [ + 'GPTForPretraining', + 'GPTForPretrainingAuto', + ] + and hasattr(model.gpt, "checkpoints") + ): + ckpts = model.gpt.checkpoints else: - exact_ckpts = recompute.checkpoints + ckpts = recompute.checkpoints else: - exact_ckpts = recompute.checkpoints + ckpts = recompute.checkpoints - # modify strategy - recompute.checkpoints = exact_ckpts[:] - logs = { - 'Model Class': model.__class__.__name__, - 'Applied Recompute ckpts': exact_ckpts, - } - logging.info(logs) + if not ckpts: + return + + block = program.global_block() + rc_state = RecomputeState(block, block.ops) + rc_state.build_stats() + checkpoints = rc_state.sort_checkpoints(ckpts) + + segments = [] + start_idx = -1 + pre_segment_end_idx = -1 + while start_idx + 1 < len(checkpoints): + if start_idx == -1: + ckpt_name = checkpoints[start_idx + 1] + if ckpt_name not in rc_state.var_op_deps: + start_idx += 1 + continue + op_idx_list = rc_state.var_op_deps[ckpt_name]["var_as_output_ops"] + if op_idx_list and max(op_idx_list) > 0: + segments.append([0, max(op_idx_list) + 1]) + else: + flag, min_idx, max_idx = rc_state.is_subgraph( + [checkpoints[start_idx]], [checkpoints[start_idx + 1]] + ) + if flag: + min_idx = rc_state._update_segment_start( + min_idx, pre_segment_end_idx + ) + segments.append([min_idx, max_idx + 1]) + else: + logging.debug( + "Could not recompute op range [{}] - [{}] ".format( + min_idx, max_idx + 1 + ) + ) + start_idx += 1 + + for i, segment in enumerate(segments): + for j in range(segment[0], segment[1]): + block.ops[j]._set_attr( + 'op_namescope', "/auto_parallel/rc_" + str(i) + ) def get_input_split_info(cur_rank, var, dist_context): diff --git a/python/paddle/distributed/passes/auto_parallel_amp.py b/python/paddle/distributed/passes/auto_parallel_amp.py index e96cd4ec77d8f..cba613676d58d 100644 --- a/python/paddle/distributed/passes/auto_parallel_amp.py +++ b/python/paddle/distributed/passes/auto_parallel_amp.py @@ -226,6 +226,9 @@ def _insert_cast_op_forward( dist_context, out_var, ref_mapping, ref_mesh ) + op_namescope = "/" + if op.has_attr('op_namescope'): + op_namescope = op.attr('op_namescope') cast_op = self._block._insert_op_without_sync( idx, type="cast", @@ -236,6 +239,9 @@ def _insert_cast_op_forward( "out_dtype": out_var.dtype, }, ) + cast_op._set_attr( + 'op_namescope', op_namescope + ) # for recompute naive_set_dist_op_attr_for_program_by_mesh_and_mapping( cast_op, ref_mesh, ref_mapping, dist_context ) diff --git a/python/paddle/distributed/passes/auto_parallel_fp16.py b/python/paddle/distributed/passes/auto_parallel_fp16.py index 7aed31b01ec2b..0e834343e2800 100644 --- a/python/paddle/distributed/passes/auto_parallel_fp16.py +++ b/python/paddle/distributed/passes/auto_parallel_fp16.py @@ -22,13 +22,12 @@ get_world_process_group, ) from paddle.distributed.auto_parallel.utils import ( - OP_ROLE_KEY, - OpRole, is_backward_op, is_forward_op, naive_set_dist_op_attr_for_program_by_mesh_and_mapping, set_var_dist_attr, ) +from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole from paddle.fluid import unique_name from paddle.fluid.contrib.mixed_precision.fp16_utils import ( AutoMixedPrecisionLists, @@ -417,6 +416,9 @@ def _insert_forward_cast_ops( dist_context, cast_var, ref_mapping, ref_mesh ) + op_namescope = "/" + if op.has_attr('op_namescope'): + op_namescope = op.attr('op_namescope') cast_op = block._insert_op_without_sync( idx, type="cast", @@ -428,6 +430,9 @@ def _insert_forward_cast_ops( OP_ROLE_KEY: OpRole.Forward, }, ) + cast_op._set_attr( + 'op_namescope', op_namescope + ) # for recompute naive_set_dist_op_attr_for_program_by_mesh_and_mapping( cast_op, ref_mesh, ref_mapping, dist_context ) diff --git a/python/paddle/distributed/passes/auto_parallel_grad_clip.py b/python/paddle/distributed/passes/auto_parallel_grad_clip.py index af5259680e4a5..7258eca661d63 100644 --- a/python/paddle/distributed/passes/auto_parallel_grad_clip.py +++ b/python/paddle/distributed/passes/auto_parallel_grad_clip.py @@ -17,6 +17,7 @@ import numpy as np import paddle +from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole from ..auto_parallel.dist_attribute import ( OperatorDistributedAttribute, @@ -25,8 +26,6 @@ from ..auto_parallel.process_group import get_world_process_group from ..auto_parallel.reshard import Resharder from ..auto_parallel.utils import ( - OP_ROLE_KEY, - OpRole, _get_comm_group, insert_dependencies_for_two_vars, is_gradient_clip_op, diff --git a/python/paddle/distributed/passes/auto_parallel_gradient_merge.py b/python/paddle/distributed/passes/auto_parallel_gradient_merge.py index c4ccb89d2f56f..1ec482e5cdfdc 100644 --- a/python/paddle/distributed/passes/auto_parallel_gradient_merge.py +++ b/python/paddle/distributed/passes/auto_parallel_gradient_merge.py @@ -19,12 +19,11 @@ get_world_process_group, ) from paddle.distributed.auto_parallel.utils import ( - OP_ROLE_KEY, - OpRole, is_optimize_op, naive_set_dist_op_attr_for_program_by_mesh_and_mapping, set_var_dist_attr, ) +from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole from paddle.fluid import layers from paddle.fluid.framework import device_guard from paddle.framework import core diff --git a/python/paddle/distributed/passes/auto_parallel_recompute.py b/python/paddle/distributed/passes/auto_parallel_recompute.py index aa213e2432232..d99f335517a16 100644 --- a/python/paddle/distributed/passes/auto_parallel_recompute.py +++ b/python/paddle/distributed/passes/auto_parallel_recompute.py @@ -14,19 +14,8 @@ import logging -from paddle.distributed.auto_parallel.dist_attribute import ( - OperatorDistributedAttribute, -) -from paddle.distributed.auto_parallel.utils import ( - get_loss_op, - insert_dependencies_for_two_ops, - naive_set_dist_op_attr_for_program_by_mesh_and_mapping, - set_dist_op_desc_original_id, - set_var_dist_attr, -) -from paddle.fluid import core -from paddle.fluid import framework as framework -from paddle.fluid import unique_name +from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole +from paddle.fluid import core, framework, unique_name from paddle.fluid.backward import ( ProgramStats, _append_grad_suffix_, @@ -35,28 +24,43 @@ _rename_arg_, ) +from ..auto_parallel.dist_attribute import OperatorDistributedAttribute +from ..auto_parallel.utils import ( + get_loss_op, + insert_dependencies_for_two_ops, + is_backward_op, + is_recompute_op, + naive_set_dist_op_attr_for_program_by_mesh_and_mapping, + set_dist_op_desc_original_id, + set_var_dist_attr, +) from .pass_base import PassBase, register_pass -def _to_be_recomputed(op): - return op.has_attr('op_namescope') and "/auto_parallel/rc_" in op.attr( - 'op_namescope' - ) - - class RecomputeState(ProgramStats): def __init__(self, block, ops): super().__init__(block=block, ops=ops) - self._block = block - self._ops = ops - # {varname: {as_input_ops: op_idx, as_output_ops: op_idx}} - self.var_op_deps = {} - # {segment_name: op_idx} self.seg_op_deps = {} + self._checkpoints = [] + self._reserved_vars = [] + + @property + def checkpoints(self): + return self._checkpoints + + @property + def reserved_vars(self): + return self._reserved_vars - def build_stats(self): - for i, op in enumerate(self._ops): - for name in op.desc.input_arg_names(): + def is_recompute(self): + return any([is_recompute_op(op) for op in self.ops]) + + def build_states(self): + for i, op in enumerate(self.ops): + if is_backward_op(op): + break + + for name in op.input_arg_names: if name in self.var_op_deps: self.var_op_deps[name]["var_as_input_ops"].extend([i]) else: @@ -64,7 +68,7 @@ def build_stats(self): self.var_op_deps[name]["var_as_input_ops"] = [i] self.var_op_deps[name]["var_as_output_ops"] = [] - for name in op.desc.output_arg_names(): + for name in op.output_arg_names: if name in self.var_op_deps: self.var_op_deps[name]["var_as_output_ops"].extend([i]) else: @@ -72,7 +76,8 @@ def build_stats(self): self.var_op_deps[name]["var_as_input_ops"] = [] self.var_op_deps[name]["var_as_output_ops"] = [i] - if not _to_be_recomputed(op): + if not is_recompute_op(op): + self._checkpoints.extend(op.output_arg_names) continue seg_name = op.attr('op_namescope') @@ -84,97 +89,42 @@ def build_stats(self): ), "The recompute segment's ops should be continuous" self.seg_op_deps[seg_name].extend([i]) - def get_recompute_segments( - self, checkpoints_list=None, no_recompute_segments=[] - ): - """get recompute segments and checkpoints""" + def get_recompute_segments(self, no_recompute_segments=[]): segments = [] - checkpoints = checkpoints_list or [] - - if len(checkpoints) == 0: - # the segments is marked by `auto.recompute()` api - for segment_idx in self.seg_op_deps.values(): - if len(segment_idx) == 1: - continue - segments.append([segment_idx[0], segment_idx[-1] + 1]) - checkpoints.extend(self._ops[segment_idx[-1]].output_arg_names) - else: - # the segments is marked by `strategy.checkpoints` api - start_idx = -1 - pre_segment_end_idx = -1 - while start_idx + 1 < len(checkpoints): - if start_idx == -1: - ckpt_name = checkpoints[start_idx + 1] - if ckpt_name not in self.var_op_deps: - start_idx += 1 - continue - op_idx_list = self.var_op_deps[ckpt_name][ - "var_as_output_ops" - ] - if op_idx_list: - segments.append([0, max(op_idx_list) + 1]) - else: - flag, min_idx, max_idx = self.is_subgraph( - [checkpoints[start_idx]], [checkpoints[start_idx + 1]] - ) - if flag: - min_idx = self._update_segment_start( - min_idx, pre_segment_end_idx - ) - segments.append([min_idx, max_idx + 1]) - else: - logging.info( - "Could not recompute op range [{}] - [{}] ".format( - min_idx, max_idx + 1 - ) - ) - start_idx += 1 - - if no_recompute_segments: - for i in reversed(sorted(no_recompute_segments)): - assert i < len( - segments - ), "the no_recompute_segments idx [{}] should be lower the number of segment [{}]".format( - i, len(segments) - ) - segments.pop(i) - - for i, (idx1, idx2) in enumerate(segments): - logging.info("recompute segment[{}]".format(i)) - logging.info( - "segment start op: [{}]: [{}] [{}]".format( - self._ops[idx1].desc.type(), - self._ops[idx1].desc.input_arg_names(), - self._ops[idx1].desc.output_arg_names(), - ) - ) - logging.info( - "segment end op: [{}]: [{}] [{}]".format( - self._ops[idx2 - 1].desc.type(), - self._ops[idx2 - 1].desc.input_arg_names(), - self._ops[idx2 - 1].desc.output_arg_names(), - ) + for segment_idx in self.seg_op_deps.values(): + if len(segment_idx) == 1: + continue + segments.append([segment_idx[0], segment_idx[-1] + 1]) + self._checkpoints.extend(self.ops[segment_idx[-1]].output_arg_names) + + for i in reversed(sorted(no_recompute_segments)): + assert i < len( + segments + ), "the no_recompute_segments idx [{}] should be lower the number of segment [{}]".format( + i, len(segments) ) + segments.pop(i) - return segments, checkpoints - - def is_recompute(self): - return any([_to_be_recomputed(op) for op in self._ops]) + return segments def modify_forward_desc_for_recompute(self, dist_context): """ If program's foward part has 'dropout' op, this function will insert a seed op before it to guarantee that two dropout op have the same outputs. """ - op_types = [op.desc.type() for op in self._ops] + op_types = [op.type for op in self.ops] if "dropout" not in op_types: return op_idx = 0 - while op_idx < len(self._ops): - cur_op = self._ops[op_idx] + while op_idx < len(self.ops): + cur_op = self.ops[op_idx] if "grad" in cur_op.type: break + if cur_op.type == "seed": + self._reserved_vars.extend(cur_op.output_arg_names) + op_idx += 1 + continue if cur_op.type != "dropout": op_idx += 1 continue @@ -188,7 +138,8 @@ def modify_forward_desc_for_recompute(self, dist_context): var_unique_name = unique_name.generate_with_ignorable_key( ".".join([op_unique_name, 'tmp']) ) - seed_var = self._block.create_var( + self._reserved_vars.append(var_unique_name) + seed_var = self.block.create_var( name=var_unique_name, dtype='int32', type=core.VarDesc.VarType.LOD_TENSOR, @@ -209,7 +160,7 @@ def modify_forward_desc_for_recompute(self, dist_context): else int(cur_op.attr("seed")) ) # TODO add dependency for seed op to ensure it be issued just before recompute. - seed_op = self._block._insert_op_without_sync( + seed_op = self.block._insert_op_without_sync( index=cur_op.idx, type="seed", inputs={}, @@ -223,7 +174,7 @@ def modify_forward_desc_for_recompute(self, dist_context): ) # modify dropout op's desc - self._ops.insert(op_idx, seed_op) + self.ops.insert(op_idx, seed_op) cur_op.desc.set_input("Seed", [var_unique_name]) cur_op._remove_attr("fix_seed") cur_op._remove_attr("seed") @@ -232,7 +183,7 @@ def modify_forward_desc_for_recompute(self, dist_context): ) op_idx += 2 - self._block._sync_with_cpp() + self.block._sync_with_cpp() def _find_op_index(block, cur_op): @@ -242,7 +193,7 @@ def _find_op_index(block, cur_op): return -1 -def _get_stop_gradients(program, no_grad_set): +def _get_stop_gradients(program, no_grad_set=None): """get no grad var""" if no_grad_set is None: no_grad_set = set() @@ -260,16 +211,15 @@ def _get_stop_gradients(program, no_grad_set): def _add_needed_descs_to_block( - descs, block, main_block, in_memory_vars, dist_context + descs, block, main_block, vars_should_be_hold, dist_context ): """ Get the recomputed ops which will insert the backward part """ if len(descs) == 0: return [] + result_descs = [] - op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName() - backward = core.op_proto_and_checker_maker.OpRole.Backward for desc in descs: if isinstance(desc, framework.Operator): desc = desc.desc @@ -279,22 +229,29 @@ def _add_needed_descs_to_block( for name in desc.output_arg_names(): if main_block.has_var(name) and main_block.var(name).persistable: continue - if name not in in_memory_vars: + if name not in vars_should_be_hold: is_needed = True if is_needed: new_op_desc = block.desc.append_op() new_op_desc.copy_from(desc) set_dist_op_desc_original_id(new_op_desc, desc, dist_context) - new_op_desc._set_attr(op_role_attr_name, backward) + new_op_desc._set_attr(OP_ROLE_KEY, OpRole.Backward) result_descs.append(new_op_desc) return result_descs +def _find_op_path(main_program, loss, no_grad_set=None): + no_grad_set_name = _get_stop_gradients(main_program, no_grad_set) + op_path = _find_op_path_( + main_program.global_block(), [loss], [], no_grad_set_name + ) + return op_path + + @register_pass("auto_parallel_recompute") class RecomputePass(PassBase): def __init__(self): super().__init__() - self.set_attr("checkpoints", None) self.set_attr("loss", None) self.set_attr("dist_context", None) self.set_attr("no_grad_set", None) @@ -311,49 +268,64 @@ def _check_conflict(self, other_pass): return True def _apply_single_impl(self, main_program, startup_program, context): - checkpoints = self.get_attr("checkpoints") - no_recompute_segments = self.get_attr("no_recompute_segments") loss = self.get_attr("loss") no_grad_set = self.get_attr("no_grad_set") + no_recompute_segments = self.get_attr("no_recompute_segments") self._dist_context = self.get_attr("dist_context") # 0. get op_path which is related to loss main_block = main_program.global_block() - no_grad_set_name = _get_stop_gradients(main_program, no_grad_set) - op_path = _find_op_path_(main_block, [loss], [], no_grad_set_name) + op_path = _find_op_path(main_program, loss, no_grad_set) # 1. build recompute state rc_state = RecomputeState(main_block, op_path) - if not rc_state.is_recompute() and not checkpoints: + if not rc_state.is_recompute(): return # 2. get the segments to be recomputed rc_state.modify_forward_desc_for_recompute(self._dist_context) - rc_state.build_stats() - checkpoints = rc_state.sort_checkpoints(checkpoints or []) - segments, checkpoints = rc_state.get_recompute_segments( - checkpoints, no_recompute_segments - ) - if segments == [] or checkpoints == []: + rc_state.build_states() + segments = rc_state.get_recompute_segments(no_recompute_segments) + if segments == []: return + for i, (idx1, idx2) in enumerate(segments): + logging.info( + "recompute segment[{}/{}]".format(i + 1, len(segments)) + ) + logging.info( + "segment start op: [{}]: [{}] [{}]".format( + rc_state.ops[idx1].type, + rc_state.ops[idx1].input_arg_names, + rc_state.ops[idx1].output_arg_names, + ) + ) + logging.info( + "segment end op: [{}]: [{}] [{}]".format( + rc_state.ops[idx2 - 1].type, + rc_state.ops[idx2 - 1].input_arg_names, + rc_state.ops[idx2 - 1].output_arg_names, + ) + ) + # 3. get vars that should be hold in memory vars_should_be_hold = [] for segment in segments: vars_should_be_hold.extend( rc_state.get_out_of_subgraph_vars(segment[0], segment[1]) ) - cross_vars = set(vars_should_be_hold) - set(checkpoints) + cross_vars = set(vars_should_be_hold) - set(rc_state.checkpoints) logging.info( "found [{}] vars which cross recompute segment: [{}]," "better checkpoints might be set to reduce those vars".format( len(cross_vars), cross_vars ) ) - vars_should_be_hold.extend(rc_state.get_reserved_vars()) + vars_should_be_hold.extend(rc_state.reserved_vars) vars_should_be_hold.extend(rc_state.get_input_nodes()) - vars_should_be_hold = list(set(vars_should_be_hold)) - vars_in_memory = vars_should_be_hold + checkpoints + vars_should_be_hold = list( + set(vars_should_be_hold) | set(rc_state.checkpoints) + ) # 4. get the fwd ops desc to be recomputed. var_name_dict = {} # varname --> varname.subprog_XXX @@ -364,20 +336,23 @@ def _apply_single_impl(self, main_program, startup_program, context): var_suffix = ".subprog_%d" % i for op in fwd_ops: input_and_output_names = [] - input_and_output_names.extend(op.desc.input_arg_names()) - input_and_output_names.extend(op.desc.output_arg_names()) + input_and_output_names.extend(op.input_arg_names) + input_and_output_names.extend(op.output_arg_names) + cur_op_dist_attr = ( self._dist_context.get_op_dist_attr_for_program(op) ) assert cur_op_dist_attr is not None + for name in input_and_output_names: - if main_block.var(name).persistable or name in checkpoints: - continue - if name in vars_should_be_hold: + if ( + main_block.var(name).persistable + or name in vars_should_be_hold + ): continue if name not in var_name_dict: ref_process_mesh = cur_op_dist_attr.process_mesh - if name in op.desc.input_arg_names(): + if name in op.input_arg_names: ref_dims_mapping = ( cur_op_dist_attr.get_input_dims_mapping(name) ) @@ -385,6 +360,7 @@ def _apply_single_impl(self, main_program, startup_program, context): ref_dims_mapping = ( cur_op_dist_attr.get_output_dims_mapping(name) ) + # record recomputed var's old_name and new_name (old_name.subprog_XXX) # create new var with new name var_name_dict[name] = name + var_suffix @@ -409,7 +385,7 @@ def _apply_single_impl(self, main_program, startup_program, context): fwd_ops, buffer_block, main_block, - vars_in_memory, + vars_should_be_hold, self._dist_context, ) # rename recomputed ops' input and output var name @@ -437,15 +413,15 @@ def _apply_single_impl(self, main_program, startup_program, context): grad_op._remove_attr("fix_seed") grad_op._remove_attr("seed") - # rename grad op's var_name which is not in 'vars_in_memory' - for key in var_name_dict: - if ( - key - not in grad_op.input_arg_names + grad_op.output_arg_names - ): + input_and_output_names = [] + input_and_output_names.extend(grad_op.input_arg_names) + input_and_output_names.extend(grad_op.output_arg_names) + + for varname in var_name_dict: + if varname not in input_and_output_names: continue self.reset_op_dist_attr(grad_op, var_name_dict) - _rename_arg_([grad_op.desc], key, var_name_dict[key]) + _rename_arg_([grad_op.desc], varname, var_name_dict[varname]) # insert recomputed ops original_id = grad_op.desc.original_id() @@ -504,13 +480,13 @@ def _apply_single_impl(self, main_program, startup_program, context): def reset_op_dist_attr(self, op, var_name_dict): op_dist_attr = self._dist_context.get_op_dist_attr_for_program(op) assert op_dist_attr is not None - for input in op.desc.input_arg_names(): + for input in op.input_arg_names: if input in var_name_dict.keys(): in_dist_attr = op_dist_attr.get_input_dist_attr(input) op_dist_attr.set_input_dist_attr( var_name_dict[input], in_dist_attr ) - for output in op.desc.output_arg_names(): + for output in op.output_arg_names: if output in var_name_dict.keys(): out_dist_attr = op_dist_attr.get_output_dist_attr(output) op_dist_attr.set_output_dist_attr( diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt index d13e9b69b578a..21c0f88438ad8 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt @@ -74,6 +74,8 @@ if(WITH_DISTRIBUTE AND WITH_GPU) set_tests_properties(test_parallel_tuner_predict PROPERTIES TIMEOUT 120) py_test_modules(test_selective_recompute MODULES test_selective_recompute) set_tests_properties(test_selective_recompute PROPERTIES TIMEOUT 50) + py_test_modules(test_tuning_recompute MODULES test_tuning_recompute) + set_tests_properties(test_tuning_recompute PROPERTIES TIMEOUT 240) py_test_modules(test_while_op_completion MODULES test_while_op_completion ENVS ${dist_ENVS}) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/get_gpt_model.py b/python/paddle/fluid/tests/unittests/auto_parallel/get_gpt_model.py index b77d42653abdb..35bf1a323d15c 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/get_gpt_model.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/get_gpt_model.py @@ -28,12 +28,9 @@ GPTPretrainingCriterion, ) -sequence_len = 512 -vocab_size = 1000 - class FakeDataset(paddle.io.Dataset): - def __init__(self, num_samples): + def __init__(self, num_samples, vocab_size=1000, sequence_len=512): self.num_samples = num_samples self.sequence_len = sequence_len self.vocab_size = vocab_size @@ -57,7 +54,7 @@ def __len__(self): return self.num_samples -def create_data_holder(batch_size): +def create_data_holder(batch_size, vocab_size=1000, sequence_len=512): tokens = paddle.static.InputSpec( name="tokens", shape=[batch_size, sequence_len], dtype='int64' ) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/optimization_tuner_api.py b/python/paddle/fluid/tests/unittests/auto_parallel/optimization_tuner_api.py index 10005008cdbe5..dfb554ac722d1 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/optimization_tuner_api.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/optimization_tuner_api.py @@ -98,7 +98,7 @@ def train(fetch): tuning.profile_start_step = 1 tuning.profile_end_step = 5 tuning.run_after_tuning = True - tuning.verbose = True + tuning.debug = True dataset = MyDataset(batch_num * batch_size) engine = auto.Engine( diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_strategy.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_strategy.py index 8649c0f8dffcd..529d1d5f6255d 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_strategy.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_strategy.py @@ -24,7 +24,7 @@ def test_default_config(self): recompute = strategy.recompute self.assertEqual(recompute.enable, False) - self.assertIsNone(recompute.checkpoints) + self.assertEqual(recompute.checkpoints, []) amp = strategy.amp self.assertEqual(amp.enable, False) @@ -66,12 +66,10 @@ def test_default_config(self): tuning = strategy.tuning self.assertEqual(tuning.enable, False) - self.assertEqual(tuning.batch_size, 1) - self.assertIsNone(tuning.dataset) self.assertEqual(tuning.profile_start_step, 1) self.assertEqual(tuning.profile_end_step, 1) self.assertEqual(tuning.run_after_tuning, True) - self.assertEqual(tuning.verbose, True) + self.assertEqual(tuning.debug, False) def test_modify_config(self): strategy = auto.Strategy() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_tuning_recompute.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_tuning_recompute.py new file mode 100644 index 0000000000000..a2a7deee6d216 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_tuning_recompute.py @@ -0,0 +1,111 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import unittest + +from get_gpt_model import FakeDataset + +import paddle +from paddle.distributed.fleet import auto + +sys.path.append("..") +import auto_parallel_gpt_model as modeling +from auto_parallel_gpt_model import ( + GPTForPretraining, + GPTModel, + GPTPretrainingCriterion, +) + + +def generate_model(): + modeling.init_global() + modeling._global_parallel_strategy = "serial" + + gpt = GPTModel( + vocab_size=50304, + hidden_size=1024, + num_hidden_layers=14, + num_attention_heads=16, + intermediate_size=1024 * 4, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=1024, + type_vocab_size=1, + initializer_range=0.02, + pad_token_id=0, + eos_token_id=7, + bos_token_id=0, + eol_token_id=3, + use_new_recompute=True, + recompute_granularity="full", + ) + model = GPTForPretraining( + gpt, vocab_size=50304, hidden_size=1024, initializer_range=0.02 + ) + criterion = GPTPretrainingCriterion() + return model, criterion + + +def apply_pass(): + strategy = auto.Strategy() + strategy.auto_mode = "semi" + + recompute = strategy.recompute + recompute.enable = True + recompute.enable_tuning = True + + tuning = strategy.tuning + tuning.enable = True + tuning.profile_start_step = 1 + tuning.profile_end_step = 2 + tuning.run_after_tuning = True + tuning.verbose = True + return strategy + + +class TestRecomputePassTuning(unittest.TestCase): + def setUp(self): + + self.batch_size = 8 + self.batch_num = 200 + self.dataset = FakeDataset( + self.batch_size * self.batch_num, + vocab_size=50304, + sequence_len=1024, + ) + + def test_recompute_pass(self): + + strategy = apply_pass() + clip = paddle.nn.ClipGradByGlobalNorm(0.2) + opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip) + model, loss = generate_model() + + engine = auto.Engine(model, loss, opt, strategy=strategy) + engine._tune(self.dataset, 3, batch_size=self.batch_size) + + assert ( + len( + engine._dist_contexts[ + 'train' + ].strategy.recompute.no_recompute_segments + ) + > 0 + ) + + +if __name__ == "__main__": + unittest.main()