From 72ef654a9f82ddb8ec87af6a1f7884c22c8da53f Mon Sep 17 00:00:00 2001 From: Shanmugam Ramasamy <111910568+shanmugamr1992@users.noreply.github.com> Date: Wed, 3 Jan 2024 18:01:55 -0800 Subject: [PATCH] Fixing wrapper and moving it to base class (#8055) * Fixing wrapper and moving it to base class * Fixing wrapper and moving it to base class * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update megatron_bert_model.py Signed-off-by: Shanmugam Ramasamy <111910568+shanmugamr1992@users.noreply.github.com> * Dummy changes * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixing bugs for bert --------- Signed-off-by: Shanmugam Ramasamy <111910568+shanmugamr1992@users.noreply.github.com> Co-authored-by: Shanmugam Ramasamy Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Sasha Meister --- .../language_modeling/megatron_base_model.py | 40 ++++++++++++++++++- .../language_modeling/megatron_bert_model.py | 38 +----------------- .../language_modeling/megatron_gpt_model.py | 33 --------------- 3 files changed, 40 insertions(+), 71 deletions(-) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py index c3046e4f01f4..731cd875ae07 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py @@ -231,6 +231,41 @@ def setup_transformer_engine_tp_groups(self): tp_group = parallel_state.get_tensor_model_parallel_group() child.set_tensor_parallel_group(tp_group) + def _wrap_model_for_O2(self): + """ Wraps self.model in a float16 wrapper if the model is using megatron amp O2. + Args: + model: The model to wrap. Can be a list of modules or a single module. + Returns: + The wrapped model. Returns a list of wrapped modules or a single wrapped module. + """ + is_mcore_model = self.__dict__.get('mcore_gpt', False) or self.__dict__.get('mcore_bert', False) + + Float16Wrapper = MCoreFloat16Module if is_mcore_model else Float16Module + + nemo_args = {'config': self.model_parallel_config, 'precision': self.cfg.precision} + + if type(self).__name__ == 'MegatronGPTModel': + nemo_args['share_token_embeddings'] = self.cfg.get('share_embeddings_and_output_weights', True) + + mcore_args = { + 'config': self.transformer_config, + } + + args = mcore_args if is_mcore_model else nemo_args + + # Model wrapper to convert both model and inputs to half precision + if isinstance(self.model, list): + converted_model = [] + for module in self.model: + args['module'] = module + converted_model.append(Float16Wrapper(**args)) + self.model = converted_model + else: + args['module'] = self.model + self.model = Float16Wrapper(**args) + + args.pop('module') + def get_model_module_list(self): if isinstance(self.model, list): return [ @@ -826,6 +861,7 @@ def is_data_parallel_rank_zero(self): def _get_total_params_across_model_parallel_groups_gpt_bert(self, model): """Returns the total number of parameters across all model parallel groups.""" + is_mcore_model = self.__dict__.get('mcore_gpt', False) or self.__dict__.get('mcore_bert', False) # log number of parameters if isinstance(model, list): num_parameters_on_device = sum( @@ -838,7 +874,7 @@ def _get_total_params_across_model_parallel_groups_gpt_bert(self, model): ): word_embeddings_weight = ( model[-1].module.shared_embedding_or_output_weight() - if getattr(self, 'mcore_gpt', False) + if is_mcore_model else model[-1].word_embeddings_weight() ) # substract the embedding weights on the last virtual stage @@ -853,7 +889,7 @@ def _get_total_params_across_model_parallel_groups_gpt_bert(self, model): ): word_embeddings_weight = ( model.module.shared_embedding_or_output_weight() - if getattr(self, 'mcore_gpt', False) + if is_mcore_model else model.word_embeddings_weight() ) # substract the embedding weights on the last stage diff --git a/nemo/collections/nlp/models/language_modeling/megatron_bert_model.py b/nemo/collections/nlp/models/language_modeling/megatron_bert_model.py index ceab93a30a1c..e4ae0f87d353 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_bert_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_bert_model.py @@ -136,40 +136,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): self._nsys_profile_start_step *= grad_accum_steps self._nsys_profile_end_step *= grad_accum_steps - def _wrap_model_for_O2(self): - """ Wraps self.model in a float16 wrapper if the model is using megatron amp O2. - Args: - model: The model to wrap. Can be a list of modules or a single module. - Returns: - The wrapped model. Returns a list of wrapped modules or a single wrapped module. - """ - Float16Wrapper = MCoreFloat16Module if self.mcore_bert else Float16Module - - nemo_args = { - 'config': self.model_parallel_config, - 'precision': self.cfg.precision, - } - mcore_args = { - 'config': self.transformer_config, - } - - args = mcore_args if self.mcore_bert else nemo_args - - # Model wrapper to convert both model and inputs to half precision - if isinstance(self.model, list): - converted_model = [] - for module in self.model: - if not self.mcore_bert: - args['module'] = module - converted_model.append(Float16Wrapper(**args)) - self.model = converted_model - else: - if not self.mcore_bert: - args['module'] = self.model - self.model = Float16Wrapper(**args) - - args.pop('module') - def model_provider_func(self, pre_process, post_process): cfg = self.cfg num_tokentypes = 2 if cfg.bert_binary_head else 0 @@ -990,7 +956,7 @@ def configure_optimizers(self): if isinstance(module, (Float16Module, MCoreFloat16Module)): module = module.module stage_bucket = [] - layers = module.transformer.layers if self.mcore_bert else module.language_model.encoder.layers + layers = module.encoder.layers if self.mcore_bert else module.language_model.encoder.layers for layer in layers: stage_bucket.extend( p for p in layer.parameters() if not getattr(p, '_disable_overlap_grad_sync', False) @@ -1002,7 +968,7 @@ def configure_optimizers(self): for module in modules: if isinstance(module, (Float16Module, MCoreFloat16Module)): module = module.module - layers = module.transformer.layers if self.mcore_bert else module.language_model.encoder.layers + layers = module.encoder.layers if self.mcore_bert else module.language_model.encoder.layers for layer in layers: buckets.append( [p for p in layer.parameters() if not getattr(p, '_disable_overlap_grad_sync', False)] diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index ad4729e910e8..b7e5fc08b1f8 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -1554,36 +1554,3 @@ def build_transformer_config(self) -> TransformerConfig: setattr(transformer_config, key, value) return transformer_config - - def _wrap_model_for_O2(self): - """ Wraps self.model in a float16 wrapper if the model is using megatron amp O2. - Args: - model: The model to wrap. Can be a list of modules or a single module. - Returns: - The wrapped model. Returns a list of wrapped modules or a single wrapped module. - """ - Float16Wrapper = MCoreFloat16Module if self.mcore_gpt else Float16Module - - nemo_args = { - 'config': self.model_parallel_config, - 'precision': self.cfg.precision, - 'share_token_embeddings': self.cfg.get('share_embeddings_and_output_weights', True), - } - mcore_args = { - 'config': self.transformer_config, - } - - args = mcore_args if self.mcore_gpt else nemo_args - - # Model wrapper to convert both model and inputs to half precision - if isinstance(self.model, list): - converted_model = [] - for module in self.model: - args['module'] = module - converted_model.append(Float16Wrapper(**args)) - self.model = converted_model - else: - args['module'] = self.model - self.model = Float16Wrapper(**args) - - args.pop('module')