From 172f42c512e1bf32554ef910fe82f07916b4d4af Mon Sep 17 00:00:00 2001
From: tju_skywalker <929019882@qq.com>
Date: Wed, 6 Sep 2023 04:47:48 +0800
Subject: [PATCH] save space when converting hf model to megatron model.
 (#25950)

* fix convert megatron model too large

* fix convert megatron model too large
---
 .../checkpoint_reshaping_and_interoperability.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/transformers/models/megatron_gpt2/checkpoint_reshaping_and_interoperability.py b/src/transformers/models/megatron_gpt2/checkpoint_reshaping_and_interoperability.py
index afafd4d7e4d107..b535e599ad6ca4 100644
--- a/src/transformers/models/megatron_gpt2/checkpoint_reshaping_and_interoperability.py
+++ b/src/transformers/models/megatron_gpt2/checkpoint_reshaping_and_interoperability.py
@@ -737,7 +737,7 @@ def convert_checkpoint_from_transformers_to_megatron(args):
         word_emb_dict = get_element_from_dict_by_path(
             output_state_dict[i], "model.language_model.embedding.word_embeddings"
         )
-        word_emb_dict["weight"] = out_word_embed[i]
+        word_emb_dict["weight"] = out_word_embed[i].clone()
 
     # Transformer layers
     print("converting transformer layers")
@@ -845,7 +845,7 @@ def convert_checkpoint_from_transformers_to_megatron(args):
             for i in range(args.target_tensor_model_parallel_size):
                 params_dict = get_element_from_dict_by_path(output_state_dict[i], "model.language_model.encoder")
                 params_dict[layer_name] = (
-                    params[i] if (op_name + "." + weight_or_bias in tensor_parallel_params) else params
+                    params[i].clone() if (op_name + "." + weight_or_bias in tensor_parallel_params) else params
                 )
 
         if pp_rank == args.target_pipeline_model_parallel_size - 1:
@@ -860,7 +860,7 @@ def convert_checkpoint_from_transformers_to_megatron(args):
             # add the LM head
             for i in range(args.target_tensor_model_parallel_size):
                 params_dict = get_element_from_dict_by_path(output_state_dict[i], "model.word_embeddings_for_head")
-                params_dict["weight"] = out_word_embed[i]
+                params_dict["weight"] = out_word_embed[i].clone()
 
     # saving the state dict as per the tp_rank and pp_rank
     for tp_rank in range(args.target_tensor_model_parallel_size):
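
Why `.clone()` saves space: `torch.chunk` (and slicing in general) returns views that share the full tensor's storage, and `torch.save` serializes the entire underlying storage for each view, so each per-rank Megatron checkpoint would otherwise contain the whole merged weight rather than its shard. Cloning materializes only the shard before it lands in `output_state_dict`. The snippet below is an illustrative sketch, not part of the patch; the embedding shape and the 4-way split are made-up example values.

```python
# Illustrative sketch only -- not from the patch. Demonstrates that saving a
# chunk view writes the full backing storage, while saving .clone() of the
# chunk writes only the shard. Shapes and the 4-way split are assumptions.
import os
import tempfile

import torch

full_word_embed = torch.randn(50304, 1024)       # stand-in for the merged embedding
shards = torch.chunk(full_word_embed, 4, dim=0)  # views sharing one storage

with tempfile.TemporaryDirectory() as tmp:
    view_path = os.path.join(tmp, "shard_view.pt")
    clone_path = os.path.join(tmp, "shard_clone.pt")

    torch.save({"weight": shards[0]}, view_path)           # serializes the full storage
    torch.save({"weight": shards[0].clone()}, clone_path)  # serializes only one shard

    print(f"view : {os.path.getsize(view_path) / 1e6:.1f} MB")   # roughly 4x larger
    print(f"clone: {os.path.getsize(clone_path) / 1e6:.1f} MB")
```

Since the converter stores a chunk of the word embeddings, the LM head, and every tensor-parallel layer weight per rank, dropping the shared storage keeps each per-rank file at roughly 1/TP of the merged size instead of a full copy.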