[AutoParallel] Fix ckpt oom paddlenlp (#9507)
* Update ckpt_converter.py
Xing-lil authored Dec 9, 2024
1 parent b894201 commit 5b54d71
Showing 1 changed file with 14 additions and 8 deletions.
22 changes: 14 additions & 8 deletions paddlenlp/trainer/utils/ckpt_converter.py
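Taken together, the hunks below keep the source checkpoint on host memory: shards are loaded as numpy arrays (return_numpy=True) and _load_state_dict is called with offload=True, while the destination auto_parallel_state_dict is moved onto the GPU entry by entry with .cuda(). Byte counts and dtype names are then read from the numpy values rather than from tensor APIs, which is what relieves the GPU OOM during checkpoint conversion.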
@@ -115,6 +115,7 @@ def load_from_hybrid_parallel_checkpoint(self):
assert self.optimizer_state_with_master_weights
model_params = {}
for state_name, state_value in self.auto_parallel_state_dict.items():
+ self.auto_parallel_state_dict[state_name] = state_value.cuda()
if state_name in self.parameter_to_structured_name.values():
model_params[state_name] = state_value
for param_name in model_params.keys():
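For context on the added `.cuda()` call: Paddle's `Tensor.cuda()` returns a copy of the tensor on the GPU, so the destination state dict is placed on the device while the source checkpoint can stay offloaded on the host. A minimal sketch of that pattern (toy names, not the converter's own code), assuming a CUDA-enabled Paddle build:

```python
import paddle

# Sketch only: move every value of a small state dict onto the GPU,
# mirroring the per-entry loop in the hunk above.
if paddle.device.is_compiled_with_cuda():
    state_dict = {"linear.weight": paddle.zeros([4, 4], dtype="float32")}
    for name, value in state_dict.items():
        state_dict[name] = value.cuda()  # returns a GPU copy of the tensor
    print(state_dict["linear.weight"].place)  # Place(gpu:0)
```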
@@ -143,7 +144,7 @@ def load_from_hybrid_parallel_checkpoint(self):
self.auto_parallel_state_dict[master_weight] = tmp_tensor

logger.info("Calling _load_state_dict to load the required weights.")
- _load_state_dict(self.auto_parallel_state_dict, source_state_dict, [metadata])
+ _load_state_dict(self.auto_parallel_state_dict, source_state_dict, [metadata], offload=True)
logger.info("Calling _load_state_dict completed, restored the required weights.")

# In this scenario, the data type of the model state is bfloat16.
@@ -157,7 +158,7 @@ def load_from_hybrid_parallel_checkpoint(self):
self.auto_parallel_state_dict.pop(master_weight_name)
else:
logger.info("Calling _load_state_dict to load the required weights.")
- _load_state_dict(self.auto_parallel_state_dict, source_state_dict, [metadata])
+ _load_state_dict(self.auto_parallel_state_dict, source_state_dict, [metadata], offload=True)
logger.info("Calling _load_state_dict completed, restored the required weights.")
logger.info("Successfully loaded hybrid_parallel checkpoint!")

@@ -275,7 +276,9 @@ def gen_metadata_and_prepare_source_state_dict(self):

# merge sharding
logger.info("First call _load_state_dict to stitch back the tensors split by sharding1 v2.")
- _load_state_dict(optimizer_state_dict, source_state_dict_for_merge_sharding, [metadata_for_merge_sharding])
+ _load_state_dict(
+     optimizer_state_dict, source_state_dict_for_merge_sharding, [metadata_for_merge_sharding], offload=True
+ )
logger.info("Completed the call _load_state_dict, concating back the tensors split by sharding.")

# Reshape
@@ -437,7 +440,9 @@ def gen_metadata_and_prepare_source_state_dict(self):
target_state_dict[key + ".beta1_pow_acc"] = paddle.zeros((1,), "float32")
target_state_dict[key + ".beta2_pow_acc"] = paddle.zeros((1,), "float32")

- _load_state_dict(target_state_dict, self.cur_rank_loaded_state_dict, [metadata_for_merge_sharding])
+ _load_state_dict(
+     target_state_dict, self.cur_rank_loaded_state_dict, [metadata_for_merge_sharding], offload=True
+ )

# Reshape
for item in cur_rank_merger_model_params:
@@ -535,7 +540,7 @@ def load_state_dict_and_rename(self):
self.cur_rank_loaded_state_dict = {}

for file in need_read_files:
- self.cur_rank_loaded_state_dict[file] = paddle.load(os.path.join(self.path, file))
+ self.cur_rank_loaded_state_dict[file] = paddle.load(os.path.join(self.path, file), return_numpy=True)

self.optimizer_state_with_master_weights = False
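The `return_numpy=True` argument is what keeps the loaded shards off the device: `paddle.load` then returns plain numpy arrays in host memory rather than framework tensors. A minimal standalone sketch, using a throwaway file path that is only illustrative:

```python
import os
import paddle

# Sketch only: save a tiny state dict, then load it back as numpy arrays
# so no GPU memory is allocated for the checkpoint contents.
path = "/tmp/demo_state.pdparams"  # hypothetical throwaway file
paddle.save({"w": paddle.randn([4, 4], dtype="float32")}, path)

host_state = paddle.load(path, return_numpy=True)
print(type(host_state["w"]))  # <class 'numpy.ndarray'>
os.remove(path)
```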

@@ -554,7 +559,8 @@ def load_state_dict_and_rename(self):
memory_size = 0
for file, state_dict in self.cur_rank_loaded_state_dict.items():
for k, v in state_dict.items():
- memory_size += v.numel().numpy() * v.element_size()
+ memory_size += v.size * v.itemsize

memory_size = memory_size / 2**20
logger.debug(
f"The current rank has finished loading the checkpoint file and has allocated {memory_size} MB of GPU memory."
@@ -767,7 +773,7 @@ def get_sharded_tensor_infos(self, file, state_dict, cur_rank_sharded_tensor_inf
[
{"tp_rank": tp_rank, "sharding_rank": sharding_rank},
state_value.shape,
- str(state_value.dtype).split(".")[1],
+ str(state_value.dtype),
file,
]
]
@@ -776,7 +782,7 @@ def get_sharded_tensor_infos(self, file, state_dict, cur_rank_sharded_tensor_inf
[
{"tp_rank": tp_rank, "sharding_rank": sharding_rank},
state_value.shape,
- str(state_value.dtype).split(".")[1],
+ str(state_value.dtype),
file,
]
)
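Dropping `.split(".")[1]` follows from the same switch to numpy values: a numpy dtype already stringifies to a bare name such as `float32`, while a Paddle tensor dtype stringifies with a `paddle.` prefix. A short illustration (standalone, not from the converter):

```python
import numpy as np
import paddle

# Sketch only: compare how the two dtype objects stringify.
print(str(np.zeros(1, dtype=np.float32).dtype))       # float32
print(str(paddle.zeros([1], dtype="float32").dtype))  # paddle.float32
```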
