[AutoParallel] recompute tuning (PaddlePaddle#48608)
* [AutoParallel] recompute tuning

* fix conflict

* update comment

* bug fix

* update rc algo

* tiny fix

* fix clear process_group

* remove comment

* update segment print

* fix import OpRole

* adapt amp pass and grad_clip pass for opt_tuner

* update tuning config

* fix import

* annotate recompute info on ops and upgrade recompute pass

* add op_namescope for seed op

* record reserved vars

* fix recompute var's dist_attr

* fix strategy unittest

* adapt for fp16

* update unittest

* revert copy opt

* update unittest

* rename set_recompute_segments

* fix unittest
zhaoyinglia authored Dec 14, 2022
1 parent d2d3908 commit 1314aa8
Showing 18 changed files with 495 additions and 249 deletions.
6 changes: 2 additions & 4 deletions python/paddle/distributed/auto_parallel/constants.py
@@ -54,7 +54,7 @@ def set_field_default_config(category, field, default_value):
#########################################
RECOMPUTE = "recompute"
set_field_default_config(RECOMPUTE, "enable", False)
set_field_default_config(RECOMPUTE, "checkpoints", None)
set_field_default_config(RECOMPUTE, "checkpoints", [])
set_field_default_config(RECOMPUTE, "no_recompute_segments", [])
set_field_default_config(RECOMPUTE, "enable_tuning", False)

@@ -113,12 +113,10 @@ def set_field_default_config(category, field, default_value):
# #########################################
TUNING = "tuning"
set_field_default_config(TUNING, "enable", False)
set_field_default_config(TUNING, "batch_size", 1)
set_field_default_config(TUNING, "dataset", None)
set_field_default_config(TUNING, "profile_start_step", 1)
set_field_default_config(TUNING, "profile_end_step", 1)
set_field_default_config(TUNING, "run_after_tuning", True)
set_field_default_config(TUNING, "verbose", True)
set_field_default_config(TUNING, "debug", False)

#########################################
# dataset configuration
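For reference, the defaults above translate into the following strategy fields. A minimal usage sketch, assuming the Strategy class defined in python/paddle/distributed/auto_parallel/strategy.py (not an official example from this commit):

from paddle.distributed.auto_parallel.strategy import Strategy

strategy = Strategy()

# Recompute: checkpoints now defaults to an empty list so segments can be
# derived automatically; no_recompute_segments lists segment indices to skip.
strategy.recompute.enable = True
strategy.recompute.checkpoints = []
strategy.recompute.no_recompute_segments = []
strategy.recompute.enable_tuning = True  # hand segment selection to the tuner

# Tuning: batch_size, dataset and verbose were removed; debug was added.
strategy.tuning.enable = True
strategy.tuning.profile_start_step = 1
strategy.tuning.profile_end_step = 1
strategy.tuning.run_after_tuning = True
strategy.tuning.debug = False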
5 changes: 3 additions & 2 deletions python/paddle/distributed/auto_parallel/engine.py
@@ -609,7 +609,9 @@ def _build(self, mode):
if mode != "train":
serial_main_prog = serial_main_prog.clone(for_test=True)

auto_utils.set_recompute_ckpts(self._model, self._strategy)
auto_utils.set_recompute_segments(
self._model, self._losses, self._strategy, serial_main_prog
)
self._dist_contexts[mode] = DistributedContext(
serial_main_prog,
serial_startup_prog,
@@ -649,7 +651,6 @@ def _optimization_tuning(self, mode, dataset, batch_size):
from .tuner.optimization_tuner import OptimizationTuner

self._optimization_tuner = OptimizationTuner(
self._tuning.to_dict(),
self._dist_contexts[mode],
dataset,
self._inputs_spec,
4 changes: 4 additions & 0 deletions python/paddle/distributed/auto_parallel/strategy.py
@@ -73,6 +73,10 @@ def __deepcopy__(self, memo):
setattr(result, k, copy.deepcopy(v, memo))
return result

def get(self, k, d=None):
result_dict = self.to_dict()
return result_dict.get(k, d)


class RecomputeConfig(BaseConfig):
def __init__(self, config_dict=None):
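The new BaseConfig.get helper lets callers read a single field with a default instead of materializing the whole dict first; the sharding algorithm below uses it to replace the old to_dict().get(...) pattern. A hedged usage sketch, again assuming the same Strategy class:

from paddle.distributed.auto_parallel.strategy import Strategy

strategy = Strategy()
# Before this change: strategy.sharding.to_dict().get("tuning_range", None)
stage_range = strategy.sharding.get("tuning_range", None)   # default when unset
skipped = strategy.recompute.get("no_recompute_segments", [])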
95 changes: 92 additions & 3 deletions python/paddle/distributed/auto_parallel/tuner/algorithms.py
@@ -16,7 +16,7 @@
import logging
from abc import ABC, abstractmethod

from ..utils import get_logger
from ..utils import get_logger, is_recompute_op
from .trial import OptimizationTunerTrial as Trial
from .trial import TrialStatus

@@ -54,7 +54,7 @@ def changed_configs(self):
def collect_model_info(self, main_prog, startup_prog):
"""
Collect the model static info (from programs) that could be used to
prune candidate trials and save tuning time.For instance,
prune candidate trials and save tuning time. For instance,
model info like the number of model parameters and activation memory could be
used to prune candidate trials and decide the next trial.
"""
@@ -116,7 +116,7 @@ def _init_spaces(self):
self._max_stage = 3
self._trial_idx = 0

stage_range = self._config.sharding.to_dict().get("tuning_range", None)
stage_range = self._config.sharding.get("tuning_range", None)
if stage_range:
assert set(stage_range).issubset(
set([0, 1, 2, 3])
@@ -157,3 +157,92 @@ def update(self, results):
)
else:
self._trial_idx += 1


@register_algor("recompute")
class ReccomputeCheckpointAlgorithm(AlgorithmBase):
def __init__(self, config):
super().__init__(config)
self._changed_configs = ["recompute"]

def collect_model_info(self, main_prog, startup_prog):
segments = []
for op in main_prog.global_block().ops:
if not is_recompute_op(op):
continue

seg_name = op.attr('op_namescope')
if seg_name not in segments:
segments.append(seg_name)

self._total_num_trial = len(segments)
self._tuning_segments = list(range(len(segments)))
self._trail_left = 0
self._trail_right = len(segments) - 1
self._trial_idx = int(0 + (len(segments) - 1) / 2)

def _init_spaces(self):
self._recompute_mode = "all"

def next_trial(self):
if self._trial_idx < self._total_num_trial:
if self._recompute_mode == "all":
self._recompute_flag = False
new_strategy = copy.deepcopy(self._config.dist_strategy)
name = "trial-recompute-all-segments"
return Trial(new_strategy, name, self.changed_configs)
elif self._recompute_mode == "none":
self._recompute_flag = False
new_strategy = copy.deepcopy(self._config.dist_strategy)
recompute = new_strategy.recompute
recompute.enable = False
name = "trial-recompute-none-segments"
return Trial(new_strategy, name, self.changed_configs)
elif self._recompute_mode == "part":
new_no_recompute = self._tuning_segments[: self._trial_idx]
new_strategy = copy.deepcopy(self._config.dist_strategy)
recompute = new_strategy.recompute
recompute.no_recompute_segments.extend(new_no_recompute)
name = "trial-recompute-part-segments-idx{}".format(
self._trial_idx
)
return Trial(new_strategy, name, self.changed_configs)
else:
return Trial(None, None, None, status=TrialStatus.STOPPED)

def update(self, results):

et = results.get("ErrorType", None)
if self._recompute_mode == "all":
if et and et == "ResourceExhaustedError":
self._trial_idx = self._total_num_trial
self._logger.info(
"Recompute all candidate segments is failed with OOM, please reduce model size or batch size."
)
else:
self._recompute_mode = "none"
elif self._recompute_mode == "none":
if et and et == "ResourceExhaustedError":
self._recompute_mode = "part"
else:
self._trial_idx = self._total_num_trial
self._logger.info(
"Recompute is unnecessary for this model size, which will reduce the Throughtput."
)
else:
if self._trail_left >= self._trail_right:
self._trial_idx = self._total_num_trial
elif et and et == "ResourceExhaustedError":
self._trail_left = self._trail_left
self._trail_right = self._trial_idx - 1
self._trial_idx = int(
self._trail_left
+ (self._trail_right - self._trail_left) / 2
)
else:
self._trail_left = self._trial_idx + 1
self._trail_right = self._trail_right
self._trial_idx = int(
self._trail_left
+ (self._trail_right - self._trail_left) / 2
)
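The part-mode update above is a binary search over how many leading segments are excluded from recomputation: an out-of-memory trial moves the right bound down, a successful trial moves the left bound up. A condensed, self-contained sketch of that search, where tune_no_recompute_count and the oom(k) probe are hypothetical stand-ins for the tuner's profiling loop:

def tune_no_recompute_count(num_segments, oom):
    # Find the largest prefix of segments that can skip recompute without OOM.
    left, right = 0, num_segments - 1
    best = 0
    while left <= right:
        mid = (left + right) // 2
        if oom(mid):
            right = mid - 1   # too many segments skipped: recompute more of them
        else:
            best = mid
            left = mid + 1    # fits in memory: try skipping more segments
    return best

# Example: memory runs out once more than 5 segments skip recomputation.
print(tune_no_recompute_count(12, lambda k: k > 5))  # -> 5

The algorithm in the diff first tries the "all" and "none" extremes, stopping early if even full recomputation OOMs or if disabling recompute already fits, before falling back to this partial search.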
50 changes: 21 additions & 29 deletions python/paddle/distributed/auto_parallel/tuner/config.py
@@ -32,14 +32,11 @@ class TuningConfig:
tuning config: configuration for the tuning process, including the mode (profile or cost model), the log dir, and extra tuning options such as the search range for a specific pass
"""

def __init__(self, user_config, strategy):
def __init__(self, strategy):

if not isinstance(strategy, Strategy):
raise TypeError("'strategy' must be object of class `Strategy`.")

if not user_config:
user_config = {}

self._tuning_passes_name = set()
self._dist_strategy = copy.deepcopy(strategy)
self._mode = None
@@ -48,9 +45,9 @@ def __init__(self, user_config, strategy):
self._project_dir = None
self._max_num_trial = None
self._early_stop = None
self._verbose = None
self._debug = None

self._initialize(user_config)
self._initialize()

@property
def mode(self):
@@ -81,29 +78,25 @@ def early_stop(self):
return self._early_stop

@property
def verbose(self):
return self._verbose
def debug(self):
return self._debug

@property
def dist_strategy(self):
return self._dist_strategy

# initialize config with user define value or default value
def _initialize(self, user_config):

self._mode = user_config.get("mode", "PROFILE")

self._profile_start_step = user_config.get("profile_start_step", 10)

self._profile_end_step = user_config.get("profile_end_step", 30)

self._max_num_trial = user_config.get("max_num_trial", 50)

self._early_stop = user_config.get("early_stop", None)
def _initialize(self):
tuning_strategy = self._dist_strategy.tuning

self._verbose = user_config.get("verbose", False)
self._mode = tuning_strategy.get("mode", "PROFILE")
self._profile_start_step = tuning_strategy.get("profile_start_step", 10)
self._profile_end_step = tuning_strategy.get("profile_end_step", 30)
self._max_num_trial = tuning_strategy.get("max_num_trial", 50)
self._early_stop = tuning_strategy.get("early_stop", None)
self._debug = tuning_strategy.get("debug", False)

project_dir = user_config.get("project_dir", None)
project_dir = tuning_strategy.get("project_dir", None)
if not project_dir:
project_dir = os.path.join(os.getcwd(), "OptimizationTuning")
self._project_dir = project_dir
@@ -116,15 +109,14 @@ def _initialize(self, user_config):
# TODO distinguish different args of each passes
self._tuning_passes_name.add(p)

config_name = p
p_dict = getattr(self._dist_strategy, config_name)
self.__dict__[config_name] = p_dict
p_strategy = getattr(self._dist_strategy, p)
self.__dict__[p] = p_strategy

# TODO verify the user defined configs
user_config_for_pass = user_config.get(p, None)
if user_config_for_pass:
for k, v in user_config_for_pass.items():
self.__dict__[config_name][k] = v
# # TODO verify the user defined configs
# tuning_config_for_pass = tuning_strategy.get(p, None)
# if tuning_config_for_pass:
# for k, v in tuning_config_for_pass.items():
# self.__dict__[p][k] = v

# (NOTE)tuning config ONLY wraps dist strategy for pass config which is to be tuned
def __getattr__(self, item):
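Knobs that the tuning strategy does not define (mode, max_num_trial, early_stop, project_dir) simply fall back to the defaults hard-coded in _initialize. A condensed sketch of that fallback, assuming the BaseConfig.get helper added in strategy.py above; resolve_tuning_knobs is a hypothetical helper, not part of the commit:

import os

def resolve_tuning_knobs(tuning_strategy):
    # Mirrors TuningConfig._initialize: read each knob from the strategy and
    # fall back to the built-in default when the field is absent.
    knobs = {
        "mode": tuning_strategy.get("mode", "PROFILE"),
        "profile_start_step": tuning_strategy.get("profile_start_step", 10),
        "profile_end_step": tuning_strategy.get("profile_end_step", 30),
        "max_num_trial": tuning_strategy.get("max_num_trial", 50),
        "early_stop": tuning_strategy.get("early_stop", None),
        "debug": tuning_strategy.get("debug", False),
    }
    project_dir = tuning_strategy.get("project_dir", None)
    knobs["project_dir"] = project_dir or os.path.join(os.getcwd(), "OptimizationTuning")
    return knobs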
(Diffs for the remaining changed files did not load and are not shown.)