diff --git a/paddlenlp/trainer/plugins/unified_checkpoint.py b/paddlenlp/trainer/plugins/unified_checkpoint.py
deleted file mode 100644
index 4bd1a1ffb08e..000000000000
--- a/paddlenlp/trainer/plugins/unified_checkpoint.py
+++ /dev/null
@@ -1,2569 +0,0 @@
-# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import copy
-import gc
-import json
-import multiprocessing
-import os
-import sys
-import time
-from multiprocessing import shared_memory
-
-import numpy as np
-import paddle
-import paddle.distributed as dist
-from paddle.distributed import fleet
-from tqdm.auto import tqdm
-
-try:
-    from paddle.base import core
-except:
-    core = None
-
-from paddlenlp.peft import LoRAModel, PrefixModelForCausalLM
-from paddlenlp.trainer.argparser import strtobool
-from paddlenlp.trainer.trainer_utils import ExplicitEnum
-from paddlenlp.trainer.utils.helper import distributed_file, distributed_isfile
-from paddlenlp.transformers.model_utils import (
-    PretrainedModel,
-    _add_variant,
-    _load_state_dict_into_model,
-    faster_set_state_dict,
-    get_parameter_dtype,
-    load_state_dict,
-    unwrap_model,
-)
-from paddlenlp.transformers.utils import (
-    device_guard,
-    dtype_byte_size,
-    get_checkpoint_shard_files,
-    is_safetensors_available,
-)
-from paddlenlp.utils.distributed import distributed_allgather, distributed_gather
-from paddlenlp.utils.env import (
-    LORA_WEIGHTS_NAME,
-    PADDLE_MASTER_WEIGHTS_INDEX_NAME,
-    PADDLE_MASTER_WEIGHTS_NAME,
-    PADDLE_OPTIMIZER_INDEX_NAME,
-    PADDLE_OPTIMIZER_NAME,
-    PADDLE_PEFT_WEIGHTS_INDEX_NAME,
-    PADDLE_WEIGHTS_INDEX_NAME,
-    PADDLE_WEIGHTS_NAME,
-    PAST_KEY_VALUES_FILE_NAME,
-    PREFIX_WEIGHTS_NAME,
-    SAFE_MASTER_WEIGHTS_INDEX_NAME,
-    SAFE_MASTER_WEIGHTS_NAME,
-    SAFE_OPTIMIZER_INDEX_NAME,
-    SAFE_OPTIMIZER_NAME,
-    SAFE_PEFT_WEIGHTS_INDEX_NAME,
-    SAFE_PEFT_WEIGHTS_NAME,
-    SAFE_WEIGHTS_INDEX_NAME,
-    SAFE_WEIGHTS_NAME,
-)
-from paddlenlp.utils.log import logger
-from paddlenlp.utils.nested import nested_copy, nested_copy_place
-from paddlenlp.utils.tools import get_env_device
-
-if is_safetensors_available():
-    from safetensors.numpy import save_file as safe_save_file
-
-    if sys.platform.startswith("win"):
-        from safetensors import safe_open
-        from safetensors.numpy import load_file
-    else:
-        from paddlenlp.utils.safetensors import fast_safe_open as safe_open
-        from paddlenlp.utils.safetensors import fast_load_file as load_file
-
-from .shared_memory_utils import (
-    _read_state_dict_from_shm,
-    _traverse_copy_to_shm,
-    create_meta_dict,
-)
-
-FP32_MASTER = "fp32_master_0"
-optimizer_scalar_name = [
-    "beta1_pow_acc_0",
-    "beta2_pow_acc_0",
-]
-optimizer_non_scaler_name = [
-    "moment1_0",
-    "moment2_0",
-    "velocity_0",
-]  # to be added
-
-
-DEST_PLACE = paddle.CPUPlace()
-if paddle.device.is_compiled_with_cuda():
-    DEST_PLACE = paddle.CUDAPinnedPlace()
-
-
-class UnifiedCheckpointOption(ExplicitEnum):
-    """
-    "- skip_save_model_weight: do not save model weights when the masters weight exist\n"
-    "- 
master_weight_compatible: 1. if the master weights exist, only load when needed\n" - " 2. if master weights does not exist, convert model weights to master weights when needed\n" - "- async_save: enable asynchronous saving checkpoints to disk\n" - "- enable_all_options: enable all optimization configurations\n" - """ - - SKIP_SAVE_MODEL_WEIGHT = "skip_save_model_weight" - MASTER_WEIGHT_COMPATIBLE = "master_weight_compatible" - ASYNC_SAVE = "async_save" - IGNORE_MERGE_OPTIMIZER = "ignore_merge_optimizer" - - -class UnifiedCheckpointHandler: - def __init__(self, args): - self.args = args - self.global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1 - - # Mainly for asynchronous saving. - self._shm_model_weight = None - self._shm_master_weight = None - self._shm_optimizer_weight = None - self._meta_dict_model = None - self._meta_dict_master_weight = None - self._meta_dict_optim = None - self._process_model_weight = None - self._process_master_weight = None - self._process_optimizer_weight = None - self._lock = None - self._shared_save_model_flag = None - self._shared_save_master_weight_flag = None - self._shared_save_optimizer_flag = None - - if "async_save" in self.args.unified_checkpoint_config: - self._lock = multiprocessing.Lock() - self._shared_save_model_path = multiprocessing.Array("c", 100000) - self._shared_save_model_signal_path = multiprocessing.Array("c", 100000) - self._shared_save_master_weight_path = multiprocessing.Array("c", 100000) - self._shared_save_master_weight_signal_path = multiprocessing.Array("c", 100000) - self._shared_save_optimizer_path = multiprocessing.Array("c", 100000) - self._shared_save_optimizer_signal_path = multiprocessing.Array("c", 100000) - self._shared_save_model_flag = multiprocessing.Array("i", 1) - self._shared_save_master_weight_flag = multiprocessing.Array("i", 1) - self._shared_save_optimizer_flag = multiprocessing.Array("i", 1) - - def _file_save_async_or_sync( - self, state_dict, path, signal_path=None, is_sync=True, state_dict_type="model_weight" - ): - if is_sync: - for k in list(state_dict.keys()): - if isinstance(state_dict[k], paddle.Tensor): - state_dict[k] = state_dict.pop(k).cpu().numpy() - safe_save_file(state_dict, path, metadata={"format": "np"}) - else: - if state_dict_type == "model_weight": - if self._shm_model_weight is None: - self._meta_dict_model, buffer_size = create_meta_dict(state_dict) - self._shm_model_weight = shared_memory.SharedMemory(create=True, size=buffer_size) - shm_state_dict = self._shm_model_weight - meta_dict = self._meta_dict_model - shared_save_flag = self._shared_save_model_flag - shared_save_path = self._shared_save_model_path - shared_save_signal_path = self._shared_save_model_signal_path - if self._process_model_weight is None: - self._process_model_weight = multiprocessing.Process( - target=self._save_file_async_in_process, - args=( - meta_dict, - self._shm_model_weight.name, - self._shared_save_model_flag, - self._shared_save_model_path, - self._shared_save_model_signal_path, - self._lock, - state_dict_type, - self.global_rank, - ), - ) - self._process_model_weight.start() - process = self._process_model_weight - elif state_dict_type == "master_weight": - if self._shm_master_weight is None: - self._meta_dict_master_weight, buffer_size = create_meta_dict(state_dict) - self._shm_master_weight = shared_memory.SharedMemory(create=True, size=buffer_size) - shm_state_dict = self._shm_master_weight - meta_dict = self._meta_dict_master_weight - shared_save_flag = 
self._shared_save_master_weight_flag - shared_save_path = self._shared_save_master_weight_path - shared_save_signal_path = self._shared_save_master_weight_signal_path - if self._process_master_weight is None: - self._process_master_weight = multiprocessing.Process( - target=self._save_file_async_in_process, - args=( - meta_dict, - self._shm_master_weight.name, - self._shared_save_master_weight_flag, - self._shared_save_master_weight_path, - self._shared_save_master_weight_signal_path, - self._lock, - "model_weight" - if "skip_save_model_weight" in self.args.unified_checkpoint_config - else state_dict_type, - self.global_rank, - ), - ) - self._process_master_weight.start() - process = self._process_master_weight - elif state_dict_type == "optimizer_weight": - if self._shm_optimizer_weight is None: - self._meta_dict_optim, buffer_size = create_meta_dict(state_dict) - self._shm_optimizer_weight = shared_memory.SharedMemory(create=True, size=buffer_size) - shm_state_dict = self._shm_optimizer_weight - meta_dict = self._meta_dict_optim - shared_save_flag = self._shared_save_optimizer_flag - shared_save_path = self._shared_save_optimizer_path - shared_save_signal_path = self._shared_save_optimizer_signal_path - if self._process_optimizer_weight is None: - self._process_optimizer_weight = multiprocessing.Process( - target=self._save_file_async_in_process, - args=( - meta_dict, - self._shm_optimizer_weight.name, - self._shared_save_optimizer_flag, - self._shared_save_optimizer_path, - self._shared_save_optimizer_signal_path, - self._lock, - state_dict_type, - self.global_rank, - ), - ) - self._process_optimizer_weight.start() - process = self._process_optimizer_weight - - while True: # wait until no process is saving. - flag_value = shared_save_flag[0] - if flag_value == 0: - break - if not process.is_alive(): - raise RuntimeError(f"The process that saves {state_dict_type} has been killed unexpectedly.") - time.sleep(0.5) - logger.info(f"Wait for the previous save process to finish saving {state_dict_type}") - # only save model weight or save master weight, we enter this loop. - self._reset_and_update(shared_save_path, path) - self._reset_and_update(shared_save_signal_path, signal_path) - _traverse_copy_to_shm(state_dict, meta_dict, shm_state_dict.buf) - with self._lock: - shared_save_flag[0] = 1 - - def _save_file_async_in_process( - self, - meta_dict, - shm_name, - shared_save_flag, - shared_save_path, - shared_save_signal_path, - lock, - state_dict_type, - global_rank, - ): - shm = shared_memory.SharedMemory(name=shm_name) - while True: - flag_value = shared_save_flag[0] # if process uses `spawn`, cannot read this value. 
- if flag_value == -1: # stop process - break - if flag_value == 0: # nothing to save - continue - if flag_value == 1: # need to save - path = shared_save_path[:].decode("utf-8").rstrip("\x00") - signal_path = shared_save_signal_path[:].decode("utf-8").rstrip("\x00") - logger.info(f"Start to async save {path}") - state_dict = _read_state_dict_from_shm(meta_dict, shm) # numpy array - safe_save_file(state_dict, path, {"format": "np"}) - del state_dict - saved_signal_path = os.path.join(signal_path, f".{state_dict_type}.done.{global_rank}") - paddle.save(global_rank, saved_signal_path) - with lock: - shared_save_flag[0] = 0 - time.sleep(0.5) - shm.close() - - def _reset_and_update(self, shared_array, new_value): - # clear array - for i in range(len(shared_array)): - shared_array[i] = b"\0" - # update array - encoded_value = new_value.encode("utf-8") - shared_array[: len(encoded_value)] = encoded_value - - def save_unified_checkpoint(self, model, optimizer, output_dir, signal_dir=None): - """save unified checkpoint - - Args: - model (PretrainedModel): model to save - output_dir (str): save dir - safe_serialization (bool, optional): use safetensors. Defaults to False. - - Raises: - ValueError: if model is not an instance of `PretrainedModel` and the model cannot be saved - """ - if isinstance(model, PretrainedModel): - model_to_save = model - elif isinstance(unwrap_model(model), PretrainedModel): - model_to_save = unwrap_model(model) - elif isinstance(model, PrefixModelForCausalLM) or isinstance(model, LoRAModel): - model_to_save = model - else: - raise ValueError("Unified checkpoint only supports PretrainedModel, LoRAModel and PrefixModelForCausalLM!") - - # Under non distributed environment. - if paddle.distributed.get_world_size() <= 1: - self.save_single_card_checkpoint(model_to_save, output_dir) - return - - skip_save_model_weight = False - if UnifiedCheckpointOption.SKIP_SAVE_MODEL_WEIGHT.value in self.args.unified_checkpoint_config: - if is_need_master_weight(optimizer, is_fp16_or_bp16=(self.args.fp16 or self.args.bf16)): - logger.info( - f"With {UnifiedCheckpointOption.SKIP_SAVE_MODEL_WEIGHT.value}, skip the model checkpoint save." - " The master weight will be loaded as model weights for next resumption." 
- ) - # not save model weight, load from master weight - skip_save_model_weight = True - - save_directory = output_dir - os.makedirs(save_directory, exist_ok=True) - if signal_dir is not None: - os.makedirs(signal_dir, exist_ok=True) # only for async save - - # save model weights - if not skip_save_model_weight: - state_dict, shard_file, sharded_index = unified_checkpoint_into_shards( - self.args, model_to_save, safe_serialization=True - ) - is_sync_save = True - if "async_save" in self.args.unified_checkpoint_config: - is_sync_save = False - self._file_save_async_or_sync( - state_dict, - path=os.path.join(save_directory, shard_file), - signal_path=signal_dir, - is_sync=is_sync_save, - state_dict_type="model_weight", - ) - if sharded_index is not None: - if isinstance(model_to_save, LoRAModel) or isinstance(model_to_save, PrefixModelForCausalLM): - index_name = SAFE_PEFT_WEIGHTS_INDEX_NAME - else: - index_name = SAFE_WEIGHTS_INDEX_NAME - path = os.path.join(output_dir, index_name) - - if self.args.should_save: - with open(path, "w") as f: - json.dump(sharded_index, f, indent=4) - - if self.args.should_save: - # Save prefix model past_key_values - if isinstance(model_to_save, PrefixModelForCausalLM): - save_prefix_past_key_value(model_to_save, save_directory) - model_to_save.prefix_config.save_pretrained(save_directory) - if isinstance(model_to_save, LoRAModel): - model_to_save.lora_config.save_pretrained(save_directory) - - # save the config - config_to_save = save_config(model_to_save) - # Attach architecture to the config - if isinstance(model_to_save, LoRAModel) or isinstance(model_to_save, PrefixModelForCausalLM): - config_to_save.architectures = [model_to_save.model.__class__.__name__] - else: - config_to_save.architectures = [model_to_save.__class__.__name__] - if self.args.should_save: - config_to_save.save_pretrained(save_directory) - # save generation config - if model_to_save.can_generate(): - model_to_save.generation_config.save_pretrained(save_directory) - paddle.device.cuda.empty_cache() - - if strtobool(os.getenv("FLAG_LLM_PDC", "False")) and self.args.should_save: - world_size = paddle.distributed.get_world_size() - save_info = { - "world_size": world_size, - "ignore_save_lr_and_optim": self.args.ignore_save_lr_and_optim, - "skip_save_model_weight": "skip_save_model_weight" in self.args.unified_checkpoint_config, - } - paddle.save(save_info, os.path.join(save_directory, ".saving_info")) - - def load_unified_checkpoint(self, model, optimizer, resume_from_checkpoint: str): - """Load potential model checkpoint - - Args: - model (PretrainedModel): Your model to load - resume_from_checkpoint (str): path of the checkpoint to load - - Returns: - None - """ - if paddle.distributed.get_world_size() <= 1: - load_single_card_checkpoint(self.args, model, resume_from_checkpoint) - return - - local_resume = check_unified_checkpoint(self.args, model, resume_from_checkpoint, safe_serialization=True) - - if not local_resume: - logger.info("Begin to dynamically load unified checkpoint!") - load_unified_checkpoint_dynamically( - self.args, model, optimizer, resume_from_checkpoint, safe_serialization=True - ) - return - - if self.args.dataset_rank == 0 or self.args.use_expert_parallel: - load_unified_checkpoint_locally(self.args, model, resume_from_checkpoint, safe_serialization=True) - - def save_non_merge_optimizer(self, model, optimizer, output_dir, signal_dir): - paddle.device.cuda.empty_cache() - optim_state_dict = nested_copy(optimizer.state_dict()) - master_weights = None - if 
"master_weights" in optim_state_dict.keys(): - master_weights = optim_state_dict["master_weights"] - optim_state_dict.pop("master_weights") - if "LR_Scheduler" in optim_state_dict.keys(): - optim_state_dict.pop("LR_Scheduler") - - # gather global master_weights status. - global_master_weights = reduce_master_weights_status(master_weights is not None) - if master_weights is None and global_master_weights: - master_weights = {} - - # get optimizer param mappings - static2struct_name_mappings = {} - state_dict = get_expected_state_dict(model) - for k, v in state_dict.items(): - static2struct_name_mappings[v.name] = k - - # rename optimizer param name - for key in list(optim_state_dict.keys()): - static_name, type_name = generate_base_static_name(key) - new_name = static2struct_name_mappings[static_name] + "/" + type_name - optim_state_dict[new_name] = optim_state_dict.pop(key) - if master_weights is not None: - for key in list(master_weights.keys()): - master_weights[static2struct_name_mappings[key]] = master_weights.pop(key) - - no_sync_kname = [] - model_state_dict = get_expected_state_dict(model) - for k, v in model_state_dict.items(): - if getattr(v, "no_sync", False): - no_sync_kname.append(k) - - hcg = fleet.get_hybrid_communicate_group() - dp_group = hcg.get_data_parallel_group() - dp_rank = dp_group.rank if dp_group.nranks > 1 else 0 - if self.args.use_expert_parallel: - for k in list(optim_state_dict.keys()): - model_k = k.split("/")[0] - if dp_rank > 0 and model_k not in no_sync_kname: - optim_state_dict.pop(k) - if master_weights is not None: - for k in list(master_weights.keys()): - model_k = k.split("/")[0] - if dp_rank > 0 and model_k not in no_sync_kname: - master_weights.pop(k) - - optimizer_name = _add_variant(SAFE_OPTIMIZER_NAME, self.args.optimizer_name_suffix) - master_weights_name = _add_variant(SAFE_MASTER_WEIGHTS_NAME, self.args.optimizer_name_suffix) - - is_sync_save = True - if "async_save" in self.args.unified_checkpoint_config: - is_sync_save = False - self._file_save_async_or_sync( - optim_state_dict, - path=os.path.join(output_dir, optimizer_name), - signal_path=signal_dir, - is_sync=is_sync_save, - state_dict_type="optimizer_weight", - ) - self._file_save_async_or_sync( - master_weights, - path=os.path.join(output_dir, master_weights_name), - signal_path=signal_dir, - is_sync=is_sync_save, - state_dict_type="master_weight", - ) - - def load_non_merge_optimizer(self, model, optimizer, resume_from_checkpoint): - # init and get optimizer LR_Scheduler - returned_optim_state_dict = nested_copy(optimizer.state_dict()) - - optimizer_name = _add_variant(SAFE_OPTIMIZER_NAME, self.args.optimizer_name_suffix) - master_weights_name = _add_variant(SAFE_MASTER_WEIGHTS_NAME, self.args.optimizer_name_suffix) - optimizer_path = os.path.join(resume_from_checkpoint, optimizer_name) - master_weights_path = os.path.join(resume_from_checkpoint, master_weights_name) - has_master_weights = True if os.path.isfile(master_weights_path) else False - - model_state_dict = get_expected_state_dict(model) - struct2static_name_mappings = {k: v.name for k, v in model_state_dict.items()} # get optimizer param mappings - optimizer_state_dict = load_file(optimizer_path) - if has_master_weights: - master_weights = load_file(master_weights_path) - - # rename and move to paddle.Tensor - for key in list(optimizer_state_dict.keys()): - key_name = key.split("/") - static_name = struct2static_name_mappings[key_name[0]] - if has_master_weights: - if model_state_dict[key_name[0]].dtype != 
core.VarDesc.VarType.FP32: - key_name = "_".join([static_name, FP32_MASTER, key_name[1]]) - else: - key_name = "_".join([static_name, key_name[1]]) - else: - key_name = "_".join([static_name, key_name[1]]) - with device_guard(): - weight = paddle.Tensor(optimizer_state_dict.pop(key), zero_copy=True) - weight = weight._copy_to(paddle.framework._current_expected_place(), False) - returned_optim_state_dict[key_name] = weight - returned_optim_state_dict[key_name].name = key_name - - if has_master_weights: - returned_optim_state_dict["master_weights"] = {} - for key in list(master_weights.keys()): - static_name = struct2static_name_mappings[key] - with device_guard(): - weight = paddle.Tensor(master_weights.pop(key), zero_copy=True) - weight = weight._copy_to(paddle.framework._current_expected_place(), False) - returned_optim_state_dict["master_weights"][static_name] = weight - returned_optim_state_dict["master_weights"][static_name].name = "_".join([static_name, FP32_MASTER]) - - return returned_optim_state_dict - - def save_unified_optimizer(self, model, optimizer, output_dir, signal_dir): - """save unified optimizer - - Args: - model (PretrainedModel): model used to get key mapping. - optimizer (Optimizer): optimizer to save - output_dir (str): Save directory. - signal_dir (str): Asynchronous saving signal directory. - - """ - - if "ignore_merge_optimizer" in self.args.unified_checkpoint_config: - self.save_non_merge_optimizer(model, optimizer, output_dir, signal_dir) - return - - if paddle.distributed.get_world_size() <= 1: - self.save_single_card_optimizer(model, optimizer, output_dir) # no need to save signal - return - - # Split into naive optimizer params and master weights. - results = unified_optimizer_into_shards(self.args, model, optimizer, safe_serialization=True) - master_weight_state_dict = None - if len(results) == 1: - optim_state_dict, shard_optim_file, sharded_optim_index = results[0] - else: - optim_state_dict, shard_optim_file, sharded_optim_index = results[0] - master_weight_state_dict, shard_master_weight_file, sharded_master_weight_index = results[1] - - paddle.device.cuda.empty_cache() - - save_directory = output_dir - os.makedirs(save_directory, exist_ok=True) - if signal_dir is not None: - os.makedirs(signal_dir, exist_ok=True) - - is_sync_save = True - if "async_save" in self.args.unified_checkpoint_config: - is_sync_save = False - self._file_save_async_or_sync( - optim_state_dict, - path=os.path.join(save_directory, shard_optim_file), - signal_path=signal_dir, - is_sync=is_sync_save, - state_dict_type="optimizer_weight", - ) - if master_weight_state_dict is not None: - self._file_save_async_or_sync( - master_weight_state_dict, - path=os.path.join(save_directory, shard_master_weight_file), - signal_path=signal_dir, - is_sync=is_sync_save, - state_dict_type="master_weight", - ) - - if sharded_optim_index is not None: - optimizer_index_name = SAFE_OPTIMIZER_INDEX_NAME - path = os.path.join(output_dir, optimizer_index_name) - if self.args.should_save: - with open(path, "w") as f: - json.dump(sharded_optim_index, f, indent=4) - - master_weights_name = SAFE_MASTER_WEIGHTS_INDEX_NAME - if UnifiedCheckpointOption.SKIP_SAVE_MODEL_WEIGHT.value in self.args.unified_checkpoint_config: - master_weights_name = SAFE_WEIGHTS_INDEX_NAME - master_path = os.path.join(output_dir, master_weights_name) - if master_weight_state_dict is not None: - if self.args.should_save: - with open(master_path, "w") as f: - json.dump(sharded_master_weight_index, f, indent=4) - - def 
load_unified_optimizer(self, args, model, optimizer, resume_from_checkpoint): - """Load potential model checkpoint - - Args: - model (PretrainedModel): Your model to load - resume_from_checkpoint (str): path of the checkpoint to load - - Returns: - None - """ - - if paddle.distributed.get_world_size() <= 1: - optim_state_dict = load_single_card_optimizer(self.args, model, optimizer, resume_from_checkpoint) - return optim_state_dict - - has_merge_optimizer_safetensors = distributed_isfile( - os.path.join(resume_from_checkpoint, SAFE_OPTIMIZER_INDEX_NAME) - ) - # If not having merge optimizer, then load non-merge optimizer. - if not has_merge_optimizer_safetensors: - if self.args.data_parallel_rank == 0 or self.args.use_expert_parallel: - returned_optim_state_dict = self.load_non_merge_optimizer( - model, - optimizer, - resume_from_checkpoint, - ) - return returned_optim_state_dict - else: - return None - - local_resume = check_unified_optimizer( - self.args, model, optimizer, resume_from_checkpoint, safe_serialization=True - ) - if not local_resume: - logger.info("Begin to dynamically load unified optimizer!") - returned_optim_state_dict = load_unified_optimizer_dynamically( - self.args, model, optimizer, resume_from_checkpoint, safe_serialization=True - ) - return returned_optim_state_dict - - if self.args.data_parallel_rank == 0 or self.args.use_expert_parallel: - returned_optim_state_dict = load_unified_optimizer_locally( - self.args, model, optimizer, resume_from_checkpoint, safe_serialization=True - ) - return returned_optim_state_dict - return None - - def save_single_card_checkpoint(self, model_to_save, output_dir): - """Save checkpoint for non-distributed environment.""" - - state_dict = get_expected_state_dict(model_to_save) - if isinstance(model_to_save, LoRAModel) or isinstance(model_to_save, PrefixModelForCausalLM): - weight_filename = "peft_model-00001-of-00001.safetensors" - index_filename = SAFE_PEFT_WEIGHTS_INDEX_NAME - else: - weight_filename = "model-00001-of-00001.safetensors" - index_filename = SAFE_WEIGHTS_INDEX_NAME - # get index json - index_weight_file = {} - total_size = 0 - for key, weight in state_dict.items(): - index_weight_file[key] = weight_filename - total_size += weight.numel().item() * dtype_byte_size(weight.dtype) - sharded_index_json = {} - sharded_index_json["metadata"] = {"total_size": total_size} - sharded_index_json["weight_map"] = index_weight_file - if isinstance(model_to_save, LoRAModel): - sharded_index_json["type"] = "lora" - elif isinstance(model_to_save, PrefixModelForCausalLM): - sharded_index_json["type"] = "ptuning" - - os.makedirs(output_dir, exist_ok=True) - path = os.path.join(output_dir, index_filename) - with open(path, "w") as f: - json.dump(sharded_index_json, f, indent=4) - - # save checkpoint - self._file_save_async_or_sync( - state_dict, path=os.path.join(output_dir, weight_filename), is_sync=True, state_dict_type="model_weight" - ) - - if isinstance(model_to_save, PrefixModelForCausalLM): - save_prefix_past_key_value(model_to_save, output_dir) - model_to_save.prefix_config.save_pretrained(output_dir) - if isinstance(model_to_save, LoRAModel): - model_to_save.lora_config.save_pretrained(output_dir) - - config_to_save = save_config(model_to_save) - config_to_save.architectures = [model_to_save.__class__.__name__] - config_to_save.save_pretrained(output_dir) - - # save generation config - if model_to_save.can_generate(): - model_to_save.generation_config.save_pretrained(output_dir) - - def save_single_card_optimizer(self, model, 
optimizer, output_dir): - """ "Save optimizer for non-distributed environment.""" - # Split into optimizer params and master weights. - optim_state_dict = nested_copy(optimizer.state_dict()) - master_weights = None - if "master_weights" in optim_state_dict.keys(): - master_weights = optim_state_dict.pop("master_weights") - if "LR_Scheduler" in optim_state_dict.keys(): - optim_state_dict.pop("LR_Scheduler") - - static2struct_name_mappings = {} - state_dict = get_expected_state_dict(model) - fp32_weight = {} - for k, v in state_dict.items(): - static2struct_name_mappings[v.name] = k - if master_weights is not None and v.dtype == core.VarDesc.VarType.FP32: - fp32_weight[k] = v - - # rename optimizer param - for key in list(optim_state_dict.keys()): - static_name, type_name = generate_base_static_name(key) - new_name = static2struct_name_mappings[static_name] + "/" + type_name - optim_state_dict[new_name] = optim_state_dict.pop(key) - if master_weights is not None: - for key in list(master_weights.keys()): - master_weights[static2struct_name_mappings[key]] = master_weights.pop(key) - master_weights.update(fp32_weight) - - # save index json - index_optimizer_file, index_master_weight_file = {}, {} - total_optim_size, total_master_weight_size = 0, 0 - for key, weight in optim_state_dict.items(): - index_optimizer_file[key] = "optimizer-00001-of-00001.safetensors" - total_optim_size += weight.numel().item() * dtype_byte_size(weight.dtype) - if master_weights is not None: - for key, weight in master_weights.items(): - index_master_weight_file[key] = "master_weights-00001-of-00001.safetensors" - total_master_weight_size += weight.numel().item() * dtype_byte_size(weight.dtype) - path = os.path.join(output_dir, SAFE_OPTIMIZER_INDEX_NAME) - master_path = os.path.join(output_dir, SAFE_MASTER_WEIGHTS_INDEX_NAME) - with open(path, "w") as f: - has_master_weights = master_weights is not None - json.dump( - { - "metadata": {"total_size": total_optim_size}, - "weight_map": index_optimizer_file, - "master_weights": has_master_weights, - }, - f, - indent=4, - ) - if master_weights is not None: - with open(master_path, "w") as f: - json.dump( - {"metadata": {"total_size": total_master_weight_size}, "weight_map": index_master_weight_file}, - f, - indent=4, - ) - - # save optimizer state dict - self._file_save_async_or_sync( - optim_state_dict, - path=os.path.join(output_dir, "optimizer-00001-of-00001.safetensors"), - is_sync=True, - state_dict_type="optimizer_weight", - ) - if master_weights is not None: - self._file_save_async_or_sync( - master_weights, - path=os.path.join(output_dir, "master_weights-00001-of-00001.safetensors"), - is_sync=True, - state_dict_type="master_weight", - ) - - def unlink_shared_memory(self): - if not ("async_save" in self.args.unified_checkpoint_config): - return - - if self._shared_save_model_flag is not None: - while self._shared_save_model_flag[0] > 0: # async process is saving - if not self._process_model_weight.is_alive(): - raise RuntimeError("The process that saves model_weight has been killed unexpectedly.") - time.sleep(0.5) - self._shared_save_model_flag[0] = -1 - if self._shared_save_master_weight_flag is not None: - while self._shared_save_master_weight_flag[0] > 0: - if not self._process_master_weight.is_alive(): - raise RuntimeError("The process that saves master_weight has been killed unexpectedly.") - time.sleep(0.5) - self._shared_save_master_weight_flag[0] = -1 - if self._shared_save_optimizer_flag is not None: - while self._shared_save_optimizer_flag[0] > 0: - if 
not self._process_optimizer_weight.is_alive(): - raise RuntimeError("The process that saves optimizer_weight has been killed unexpectedly.") - time.sleep(0.5) - self._shared_save_optimizer_flag[0] = -1 - - if self._shm_model_weight is not None: - self._shm_model_weight.close() - self._shm_model_weight.unlink() - self._shm_model_weight = None - if self._shm_master_weight is not None: - self._shm_master_weight.close() - self._shm_master_weight.unlink() - self._shm_master_weight = None - if self._shm_optimizer_weight is not None: - self._shm_optimizer_weight.close() - self._shm_optimizer_weight.unlink() - self._shm_optimizer_weight = None - - if paddle.distributed.get_world_size() > 1: - dist.barrier() - - -def load_unified_checkpoint_locally(args, model, resume_from_checkpoint: str, safe_serialization=False): - """ - Only dataset_rank == 0 or using expert parallel can enter this function. - """ - index_filename = select_model_weight_index(args, model, resume_from_checkpoint, safe_serialization, local=True) - - resolved_archive_file, sharded_metadata = get_checkpoint_shard_files( - pretrained_model_name_or_path=resume_from_checkpoint, - index_filename=os.path.join(resume_from_checkpoint, index_filename), - ) - loaded_keys = sharded_metadata["all_checkpoint_keys"] - - model_state_dict = get_expected_state_dict(model) - # If using expert parallel, when dp_rank > 0, need to modify the expected_keys here. - if not args.use_expert_parallel or (args.use_expert_parallel and args.data_parallel_rank == 0): - expected_keys = set(list(model_state_dict.keys())) - else: - expected_keys = set() - for key in model_state_dict.keys(): - if getattr(model_state_dict[key], "no_sync", False): - expected_keys.add(key) - missing_keys = expected_keys - set(loaded_keys) - - use_fast_set = True - if isinstance(model, LoRAModel) or isinstance(model, PrefixModelForCausalLM): - use_fast_set = False - - if len(missing_keys) > 0: - raise ValueError(f"missing_keys: {missing_keys}") - - def _remove_unused_keys( - state_dict, - model_state_dict, - ): - unused_keys = set(state_dict.keys()) - set(model_state_dict.keys()) - for unused_key in unused_keys: - del state_dict[unused_key] - return unused_keys - - # This should always be a list but, just to be sure. - if not isinstance(resolved_archive_file, list): - resolved_archive_file = [resolved_archive_file] - - error_msgs = [] - - if len(resolved_archive_file) > 1: - resolved_archive_file = tqdm(resolved_archive_file, desc="Loading checkpoint shards") - - for shard_file in resolved_archive_file: - # TODO: check if no expected_keys in shard_file, then don't load it - if expected_keys.isdisjoint(sharded_metadata["file_map"][os.path.split(shard_file)[-1]]): - continue - - pre_tensor_parallel_split = False - if shard_file.endswith(".safetensors") and model.config.tensor_parallel_degree > 1: - pre_tensor_parallel_split = True - assert loaded_keys is not None, "loaded_keys is not None." - if isinstance(model, LoRAModel) or isinstance(model, PrefixModelForCausalLM): - tp_actions = model._get_tensor_parallel_convert_actions( - set(loaded_keys), is_split=True, ignore_error=True - ) - else: - tp_actions = model.get_tensor_parallel_convert_actions(model.config, loaded_keys, ignore_error=True) - # Here we use expected_keys to optimize weights loading for pipeline model. 
Only works for safetensors - state_dict = load_state_dict( - shard_file, tp_actions if pre_tensor_parallel_split else None, expected_keys, device="expected" - ) - - if not pre_tensor_parallel_split: - # Since we load all keys but we only need one of pipeline stages - _ = _remove_unused_keys(state_dict, model_state_dict) - - if model.config.tensor_parallel_degree > 1 and not pre_tensor_parallel_split: - logger.info("Converting state_dict to Tensor Parallel Format") - # ignore error for multi shard, since only parts of data - state_dict = model.convert_tensor_parallel( - None, model.config, state_dict=state_dict, ignore_error=len(resolved_archive_file) > 1 - ) - - if use_fast_set: - error_msgs += faster_set_state_dict(model, state_dict, strict_dtype=False) - else: - error_msgs += _load_state_dict_into_model(model, state_dict, "") - - # force memory release - del state_dict - # gc.collect() - - if len(error_msgs) > 0: - error_msg = "\n\t".join(error_msgs) - if " but the expected shape is" in error_msg: - error_msg += ( - "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method." - ) - raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}") - - -def save_config(model_to_save): - dtype = get_parameter_dtype(model_to_save) - model_to_save.config.dtype = str(dtype).split(".")[1] - config_to_save = copy.deepcopy(model_to_save.config) - - if config_to_save.tensor_parallel_degree > 1: - # do we need to change? - config_to_save.tensor_parallel_degree = 1 - - return config_to_save - - -def unified_checkpoint_into_shards( - args, - model_to_save, - safe_serialization=False, -): - """Get state_dict and config to save - - Args: - model_to_save (nn.Layer): model to, save - safe_serialization (bool, optional): safe serialization using safetensors. Defaults to False. - - Returns: - tuple: state_dict, config, shard_file: file name, sharded_index: map for weight to file name. - """ - paddle.device.cuda.empty_cache() - assert hasattr(model_to_save, "config") - - state_dict = get_expected_state_dict(model_to_save) - all_filter_keys = filter_params(model_to_save, state_dict) - - config_to_save = copy.deepcopy(model_to_save.config) - - if config_to_save.tensor_parallel_degree > 1: - if isinstance(model_to_save, LoRAModel) or isinstance(model_to_save, PrefixModelForCausalLM): - tp_actions = model_to_save._get_tensor_parallel_convert_actions( - all_filter_keys, is_split=False, ignore_error=True - ) - else: - tp_actions = model_to_save.get_tensor_parallel_convert_actions( - model_to_save.config, state_dict.keys(), is_split=False, ignore_error=True - ) - logger.info("Unified model tensor parallel weights in shards") - state_dict = merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys) - - # build index json file - index_weight_file = {} - total_size = 0 - if isinstance(model_to_save, LoRAModel): - weights_name = SAFE_PEFT_WEIGHTS_NAME if safe_serialization else LORA_WEIGHTS_NAME - elif isinstance(model_to_save, PrefixModelForCausalLM): - weights_name = SAFE_PEFT_WEIGHTS_NAME if safe_serialization else PREFIX_WEIGHTS_NAME - else: - weights_name = SAFE_WEIGHTS_NAME if safe_serialization else PADDLE_WEIGHTS_NAME - - shard_file = get_sharded_file_name(args, weights_name) - # renumerize shard_file name for expert_parallel. 
- if args.use_expert_parallel: - shard_file = rename_shard_file(args, shard_file, weights_name) - - for key, weight in state_dict.items(): - index_weight_file[key] = shard_file - total_size += weight.numel().item() * dtype_byte_size(weight.dtype) - - index_file_list, total_size_list = gather_sharded_object( - index_weight_file, total_size, use_expert_parallel=args.use_expert_parallel - ) - sharded_index = get_sharded_index( - index_file_list, - total_size_list, - ) - if sharded_index is not None: - if isinstance(model_to_save, LoRAModel): - sharded_index["type"] = "lora" - elif isinstance(model_to_save, PrefixModelForCausalLM): - sharded_index["type"] = "ptuning" - - paddle.device.cuda.empty_cache() - - return state_dict, shard_file, sharded_index - - -def load_unified_optimizer_locally(args, model, optimizer, resume_from_checkpoint, safe_serialization=False): - # init and get optimizer LR_Scheduler - returned_optim_state_dict = nested_copy(optimizer.state_dict()) - - if not safe_serialization: - index_filename, index_filename_master_weights = ( - PADDLE_OPTIMIZER_INDEX_NAME, - PADDLE_MASTER_WEIGHTS_INDEX_NAME, - ) - else: - index_filename, index_filename_master_weights = SAFE_OPTIMIZER_INDEX_NAME, SAFE_MASTER_WEIGHTS_INDEX_NAME - - resolved_archive_file, sharded_metadata = get_optimizer_shard_files( - optimizer_path=resume_from_checkpoint, - index_filename=os.path.join(resume_from_checkpoint, index_filename), - ) - has_master_weights = True if sharded_metadata["master_weights"] else False - - model_state_dict = get_expected_state_dict(model) - model_keys = list(model_state_dict.keys()) - struct2static_name_mappings = {k: v.name for k, v in model_state_dict.items()} # get optimizer param mappings - - expected_keys = get_expected_keys(args, sharded_metadata, model, optimizer) - - # This should always be a list but, just to be sure. - if not isinstance(resolved_archive_file, list): - resolved_archive_file = [resolved_archive_file] - - if len(resolved_archive_file) > 1: - resolved_archive_file = tqdm(resolved_archive_file, desc="Loading optimizer shards") - - # update has_master_weights and index_filename_master_weights - # 1. if the master weight exists, only has_master_weights is set True and loaded when needed - # 2. 
if master weight does not exist, convert model weight to master weight when needed - has_master_weights, index_filename_master_weights = update_master_weight_status( - args, optimizer, has_master_weights, safe_serialization - ) - - if has_master_weights: - returned_optim_state_dict["master_weights"] = {} - - resolved_archive_file_mw, sharded_metadata_mw = get_optimizer_shard_files( - optimizer_path=resume_from_checkpoint, - index_filename=os.path.join(resume_from_checkpoint, index_filename_master_weights), - ) - - expected_keys_mw = get_expected_keys(args, sharded_metadata_mw, model, optimizer, is_master_weights=True) - if not isinstance(resolved_archive_file_mw, list): - resolved_archive_file_mw = [resolved_archive_file_mw] - if len(resolved_archive_file_mw) > 1: - resolved_archive_file_mw = tqdm(resolved_archive_file_mw, desc="Loading master weights shards") - - def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected_keys, is_master_weights=False): - returned_state_dict = {} - # load optimizer - for shard_file in resolved_archive_file: - # TODO: check if no expected_keys in shard_file, then don't load it - if expected_keys.isdisjoint(sharded_metadata["file_map"][os.path.split(shard_file)[-1]]): - continue - - if shard_file.endswith(".safetensors"): - # assert model_keys is not None, "model_keys is None." TODO: correct the assert - if model.config.tensor_parallel_degree > 1: - if isinstance(model, LoRAModel) or isinstance(model, PrefixModelForCausalLM): - tp_actions = model._get_tensor_parallel_convert_actions( - model_keys, is_split=True, ignore_error=True - ) - else: - tp_actions = model.get_tensor_parallel_convert_actions( - model.config, model_keys, ignore_error=True - ) - if not is_master_weights: - tp_actions = mapping_optimizer_tp_actions(tp_actions, expected_keys) - - # Here we use expected_keys to optimize weights loading for pipeline model. 
Only works for safetensors - state_dict = load_state_dict(shard_file, tp_actions, expected_keys, device="expected") - else: - # for pipeline model, we don't need to use tp_actions - state_dict = load_state_dict(shard_file, None, expected_keys, device="expected") - - returned_state_dict.update(state_dict) - # force memory release - del state_dict - gc.collect() - return returned_state_dict - - state_dict_optim = load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected_keys) - if has_master_weights: - state_dict_master_weight = load_resolved_archive_file( - resolved_archive_file_mw, sharded_metadata_mw, expected_keys_mw, is_master_weights=True - ) - # rename optimizer param - for key in list(state_dict_optim.keys()): - key_name = key.split("/") - static_name = struct2static_name_mappings[key_name[0]] - if has_master_weights: - if model_state_dict[key_name[0]].dtype != core.VarDesc.VarType.FP32: - key_name = "_".join([static_name, FP32_MASTER, key_name[1]]) - else: - key_name = "_".join([static_name, key_name[1]]) - else: - key_name = "_".join([static_name, key_name[1]]) - returned_optim_state_dict[key_name] = state_dict_optim.pop(key) - returned_optim_state_dict[key_name].name = key_name - - if has_master_weights: - for key in list(state_dict_master_weight.keys()): - static_name = struct2static_name_mappings[key] - returned_optim_state_dict["master_weights"][static_name] = state_dict_master_weight.pop(key) - returned_optim_state_dict["master_weights"][static_name].name = "_".join([static_name, FP32_MASTER]) - - return returned_optim_state_dict - - -def unified_optimizer_into_shards( - args, - model, - optimizer, - safe_serialization=False, -): - """Get optimizer state dict and master weight state dict. - - Args: - optimizer (Optimizer): optimizer to save. - safe_serialization (bool, optional): safe serialization using safetensors. Defaults to False. - """ - paddle.device.cuda.empty_cache() - optim_state_dict = nested_copy(optimizer.state_dict()) - master_weights = None - if "master_weights" in optim_state_dict.keys(): - master_weights = optim_state_dict["master_weights"] - optim_state_dict.pop("master_weights") - if "LR_Scheduler" in optim_state_dict.keys(): - optim_state_dict.pop("LR_Scheduler") - - # gather global master_weights status. - global_master_weights = reduce_master_weights_status(master_weights is not None) - if master_weights is None and global_master_weights: - master_weights = {} - - # get optimizer param mappings - static2struct_name_mappings = {} - state_dict = get_expected_state_dict(model) - fp32_weight = {} - for k, v in state_dict.items(): - static2struct_name_mappings[v.name] = k - if master_weights is not None and v.dtype == core.VarDesc.VarType.FP32: - if args.dataset_rank > 0: # deal with different dataset rank. 
- continue - fp32_weight[k] = v - - # rename optimizer param - for key in list(optim_state_dict.keys()): - static_name, type_name = generate_base_static_name(key) - new_name = static2struct_name_mappings[static_name] + "/" + type_name - optim_state_dict[new_name] = optim_state_dict.pop(key) - if master_weights is not None: - for key in list(master_weights.keys()): - master_weights[static2struct_name_mappings[key]] = master_weights.pop(key) - master_weights.update(fp32_weight) - - # filter optimizer param - if master_weights is not None: - filter_master_keys = filter_params(model, master_weights, is_optimizer=True) - filter_optim_keys = filter_params(model, optim_state_dict, is_optimizer=True) - - tp_group = fleet.get_hybrid_communicate_group().get_model_parallel_group() - tp_size = tp_group.nranks - - if tp_size > 1: - # get tp_actions - model_keys = [] - for key in optim_state_dict.keys(): - base_model_key = key.split("/")[0] - if base_model_key not in model_keys: - model_keys.append(base_model_key) - if isinstance(model, LoRAModel) or isinstance(model, PrefixModelForCausalLM): - tp_actions = model._get_tensor_parallel_convert_actions(model_keys, is_split=False, ignore_error=True) - else: - tp_actions = model.get_tensor_parallel_convert_actions( - model.config, model_keys, is_split=False, ignore_error=True - ) - logger.info("Unified optimizer tensor parallel in shards") - optim_state_dict = merge_tensor_parallel_for_optimizer( - optim_state_dict, - tp_actions, - filter_optim_keys, - state_dict if args.use_expert_parallel else None, - ) - paddle.device.cuda.empty_cache() - - if master_weights is not None: - logger.info("Unified master weight tensor parallel in shards") - master_weights = merge_tensor_parallel_for_optimizer( - master_weights, - tp_actions, - filter_master_keys, - state_dict if args.use_expert_parallel else None, - ) - paddle.device.cuda.empty_cache() - - # build index json file - index_optimizer_file, index_master_weight_file = {}, {} - total_optim_size, total_master_weight_size = 0, 0 - optimizer_name = SAFE_OPTIMIZER_NAME if safe_serialization else PADDLE_OPTIMIZER_NAME - master_weights_name = SAFE_MASTER_WEIGHTS_NAME if safe_serialization else PADDLE_MASTER_WEIGHTS_NAME - if UnifiedCheckpointOption.SKIP_SAVE_MODEL_WEIGHT.value in args.unified_checkpoint_config: - master_weights_name = SAFE_WEIGHTS_NAME if safe_serialization else PADDLE_WEIGHTS_NAME - shard_optimizer_file = get_sharded_file_name(args, optimizer_name, is_optimizer=True) - shard_master_weight_file = get_sharded_file_name(args, master_weights_name, is_optimizer=True) - - for key, weight in optim_state_dict.items(): - index_optimizer_file[key] = shard_optimizer_file - total_optim_size += weight.numel().item() * dtype_byte_size(weight.dtype) - - if master_weights is not None: - for key, weight in master_weights.items(): - index_master_weight_file[key] = shard_master_weight_file - total_master_weight_size += weight.numel().item() * dtype_byte_size(weight.dtype) - - index_optimizer_filelist, total_optim_size_list = gather_sharded_object( - index_optimizer_file, - total_optim_size, - is_optimizer=True, - use_expert_parallel=args.use_expert_parallel, - ) - sharded_optim_index = get_sharded_index(index_optimizer_filelist, total_optim_size_list) - if master_weights is not None: - index_master_weight_filelist, total_master_weight_size_list = gather_sharded_object( - index_master_weight_file, - total_master_weight_size, - is_optimizer=True, - use_expert_parallel=args.use_expert_parallel, - ) - 
sharded_master_weight_index = get_sharded_index(index_master_weight_filelist, total_master_weight_size_list) - - if sharded_optim_index is not None: - if master_weights is not None: - sharded_optim_index["master_weights"] = True - else: - sharded_optim_index["master_weights"] = False - - paddle.device.cuda.empty_cache() - if master_weights is None: - return [(optim_state_dict, shard_optimizer_file, sharded_optim_index)] - else: - return [ - (optim_state_dict, shard_optimizer_file, sharded_optim_index), - (master_weights, shard_master_weight_file, sharded_master_weight_index), - ] - - -def check_unified_checkpoint(args, model, resume_from_checkpoint, safe_serialization=False): - index_filename = select_model_weight_index(args, model, resume_from_checkpoint, safe_serialization, local=False) - index_filename = os.path.join(resume_from_checkpoint, index_filename) - # Find index json file and distribute this file in global group. - if distributed_isfile(index_filename): - distributed_file(index_filename) - else: - raise Exception( - f"Sorry, we can not find {index_filename}. This file should be appear at least on one machine." - ) - - with open(index_filename, "r") as f: - index = json.loads(f.read()) - all_weight_filenames = sorted(set(index["weight_map"].values())) - - # Get existed weight file list on current machine. - existed_filelist = [] - existed_files = [] - for filename in os.listdir(resume_from_checkpoint): - if filename in all_weight_filenames: - existed_files.append(filename) - - # Gather all the existed files in global group. - dist.all_gather_object(existed_filelist, existed_files) - flatten_existed_filelist = flatten_list(existed_filelist) - diff_filelist = list(set(all_weight_filenames).difference(set(flatten_existed_filelist))) - if len(diff_filelist) != 0: - raise Exception(f"Sorry, the weight file list on the machines is not complete!, missing {diff_filelist}") - - # To decide whether to load the checkpoint locally, or need to dynamically send tensors across machines. - local_resume = True - if args.dataset_rank == 0 or args.use_expert_parallel: - hcg = fleet.get_hybrid_communicate_group() - tp_group = hcg.get_model_parallel_group() - pp_group = hcg.get_pipe_parallel_group() - dp_group = hcg.get_data_parallel_group() - dp_rank = dp_group.rank if dp_group.nranks > 1 else 0 - - need_files = set() - state_dict = get_expected_state_dict(model) - for key in state_dict.keys(): - filename = index["weight_map"][key] - # When using expert parallel, there's no need to check tensors with `no_sync=False` when dp_rank > 0. 
- if args.use_expert_parallel and dp_rank > 0 and not getattr(state_dict[key], "no_sync", False): - continue - need_files.add(filename) - diff_filelist = list(need_files.difference(set(existed_files))) - num_diff = paddle.to_tensor([len(diff_filelist)]) - if tp_group.nranks > 1: - dist.all_reduce(num_diff, op=dist.ReduceOp.MAX, group=tp_group) - if pp_group.nranks > 1: - dist.all_reduce(num_diff, op=dist.ReduceOp.MAX, group=pp_group) - if args.use_expert_parallel and dp_group.nranks > 1: - dist.all_reduce(num_diff, op=dist.ReduceOp.MAX, group=dp_group) - if num_diff.item() == 0: - local_resume = True - else: - local_resume = False - local_resume = paddle.to_tensor([local_resume]) - dist.all_reduce(local_resume, op=dist.ReduceOp.MIN) - local_resume = local_resume.item() - return local_resume - - -def check_unified_optimizer(args, model, optimizer, resume_from_checkpoint, safe_serialization=False): - if not safe_serialization: - index_filename, index_filename_master_weights = PADDLE_OPTIMIZER_INDEX_NAME, PADDLE_MASTER_WEIGHTS_INDEX_NAME - else: - index_filename, index_filename_master_weights = SAFE_OPTIMIZER_INDEX_NAME, SAFE_MASTER_WEIGHTS_INDEX_NAME - index_filename = os.path.join(resume_from_checkpoint, index_filename) - index_filename_master_weights = os.path.join(resume_from_checkpoint, index_filename_master_weights) - - # Find index json file and distribute the file in global group. - if distributed_isfile(index_filename): - distributed_file(index_filename) - else: - raise Exception( - f"Sorry, we can not find {index_filename}. This file should be appear at least on one machine." - ) - - with open(index_filename, "r") as f: - index = json.loads(f.read()) - all_optimizer_filenames = sorted(set(index["weight_map"].values())) - - has_master_weights = index["master_weights"] - # update has_master_weights and index_filename_master_weights - # 1. if the master weight exists, only has_master_weights is set True and loaded when needed - # 2. if master weight does not exist, convert model weight to master weight when needed - has_master_weights, index_filename_master_weights = update_master_weight_status( - args, optimizer, has_master_weights, safe_serialization - ) - if has_master_weights: - index_filename_master_weights = os.path.join(resume_from_checkpoint, index_filename_master_weights) - if distributed_isfile(index_filename_master_weights): - distributed_file(index_filename_master_weights) - else: - raise Exception( - f"Sorry, we can not find {index_filename_master_weights}. This file should be appear at least on one machine." - ) - with open(index_filename_master_weights, "r") as f: - index_mw = json.loads(f.read()) - all_mw_filenames = sorted(set(index_mw["weight_map"].values())) - - hcg = fleet.get_hybrid_communicate_group() - tp_group = hcg.get_model_parallel_group() - pp_group = hcg.get_pipe_parallel_group() - dp_group = hcg.get_data_parallel_group() - sharding_group = hcg.get_sharding_parallel_group() - sharding_rank = sharding_group.rank - dp_rank = dp_group.rank if dp_group.nranks > 1 else 0 - struct2static_name_mappings = {k: v.name for k, v in model.state_dict().items()} - if sharding_group.nranks > 1: - param2rank = optimizer._param2rank - - def check_complete(all_filenames): - # Check whether the checkpoint files on machines are complete. If not complete, raise Exception. 
- existed_filelist = [] - existed_files = [] - for filename in os.listdir(resume_from_checkpoint): - if filename in all_filenames: - existed_files.append(filename) - - dist.all_gather_object(existed_filelist, existed_files) - flatten_existed_filelist = flatten_list(existed_filelist) - diff_filelist = list(set(all_filenames).difference(set(flatten_existed_filelist))) - if len(diff_filelist) != 0: - raise Exception( - f"Sorry, the optimizer file list on `data_parallel_rank==0` machines is not complete!, missing {diff_filelist}" - ) - return existed_files - - def check_dynamic_load(args, weight_map, existed_files, is_master_weights=False, typename_set=None): - # To decide whether to load the checkpoint locally, or need to dynamically distribute the checkpoint. - local_resume = True - if args.data_parallel_rank == 0 or args.use_expert_parallel: - need_files = set() - state_dict = get_expected_state_dict(model) - - for key in state_dict.keys(): - if sharding_group.nranks > 1: - static_name = struct2static_name_mappings.get(key, None) - param_rank = param2rank.get(static_name, None) - if param_rank != sharding_rank: - continue - - # When using expert parallel, there's no need to check tensors with `no_sync=False` when dp_rank > 0. - if args.use_expert_parallel and dp_rank > 0 and not getattr(state_dict[key], "no_sync", False): - continue - - if is_master_weights and state_dict[key].dtype == core.VarDesc.VarType.FP32: - continue - - if not is_master_weights: - for type_name in typename_set: - type_key = key + "/" + type_name - filename = weight_map[type_key] - need_files.add(filename) - else: - filename = weight_map[key] - need_files.add(filename) - - diff_filelist = list(need_files.difference(set(existed_files))) - num_diff = paddle.to_tensor([len(diff_filelist)]) - if tp_group.nranks > 1: - dist.all_reduce(num_diff, op=dist.ReduceOp.MAX, group=tp_group) - if pp_group.nranks > 1: - dist.all_reduce(num_diff, op=dist.ReduceOp.MAX, group=pp_group) - if sharding_group.nranks > 1: - dist.all_reduce(num_diff, op=dist.ReduceOp.MAX, group=sharding_group) - if args.use_expert_parallel and dp_group.nranks > 1: - dist.all_reduce(num_diff, op=dist.ReduceOp.MAX, group=dp_group) - - if num_diff.item() == 0: - local_resume = True - else: - local_resume = False - local_resume = paddle.to_tensor([local_resume]) - dist.all_reduce(local_resume, op=dist.ReduceOp.MIN) - return local_resume.item() - - # check whether the optimizer checkpoint files are complete. - existed_files = check_complete(all_optimizer_filenames) - if has_master_weights: - existed_files_mw = check_complete(all_mw_filenames) - # get optimizer's param type name, like moment1_0. 
- typename_set = set() - for key in index["weight_map"].keys(): - _, typename = key.split("/") - typename_set.add(typename) - local_resume = check_dynamic_load( - args, index["weight_map"], existed_files, is_master_weights=False, typename_set=typename_set - ) - local_resume_rw = True - if has_master_weights: - local_resume_rw = check_dynamic_load(args, index_mw["weight_map"], existed_files_mw, is_master_weights=True) - return local_resume & local_resume_rw - - -def save_prefix_past_key_value(model_to_save, save_directory): - past_key_value = model_to_save.prefix_encoder(model_to_save.prefix_tokens.unsqueeze(0).expand([1, -1])) - past_key_value = past_key_value.reshape( - [ - model_to_save.prefix_config.num_prefix_tokens, - 2, - model_to_save.prefix_config.num_hidden_layers, - model_to_save.num_heads, - model_to_save.head_dim, - ] - ) - past_key_value = paddle.transpose(past_key_value, perm=[2, 1, 3, 0, 4]).cpu().numpy() - model_to_save.prefix_config.save_pretrained(save_directory) - np.save(os.path.join(save_directory, PAST_KEY_VALUES_FILE_NAME), past_key_value) - - -def get_expected_state_dict(model_to_save): - if isinstance(model_to_save, PretrainedModel): - state_dict = model_to_save.state_dict() - if ( - hasattr(model_to_save.config, "tie_word_embeddings") - and model_to_save.config.tie_word_embeddings - and hasattr(model_to_save, "_tied_weights_keys") - and model_to_save._tied_weights_keys is not None - ): - for key in model_to_save._tied_weights_keys: - if key in state_dict: - state_dict.pop(key) - elif isinstance(model_to_save, LoRAModel): - state_dict = model_to_save.get_trainable_state_dict() - elif isinstance(model_to_save, PrefixModelForCausalLM): - state_dict = model_to_save.prefix_encoder.state_dict() - - return state_dict - - -def create_dispatch_table(args, model, file_keyname_mappings, file_machine_mappings, resume_from_checkpoint): - """Create dispatch table for dynamically loading state dict. - - Args: - args - """ - - hcg = fleet.get_hybrid_communicate_group() - tp_group = hcg.get_model_parallel_group() - tp_rank = tp_group.rank - - # Create tensor receive table, contains {"key0": [global_rank, tp_rank], "key1": [global_rank, tp_rank]} - dispatch_list = [] - recv_table = {} - if args.dataset_rank == 0: - state_dict = get_expected_state_dict(model) - for (k, v) in state_dict.items(): - if hasattr(v, "is_distributed") and v.is_distributed: - recv_table[k] = [(dist.get_rank(), tp_rank)] - else: - recv_table[k] = [(dist.get_rank(), -1)] - - # Gather receive table in global group. - dist.all_gather_object(dispatch_list, recv_table) - recv_table = {} - for dl in dispatch_list: - for key, value in dl.items(): - if key not in recv_table: - recv_table[key] = value - else: - recv_table[key] += value - - # Create send table, to decide which worker to send the key. 
Contains {"key0:" global_rank, "key1": global_rank, ...} - send_table = create_send_table(file_keyname_mappings, file_machine_mappings) - - return send_table, recv_table - - -def create_optimizer_dispatch_table( - args, - model, - optimizer, - file_keyname_mappings, - file_machine_mappings, - resume_from_checkpoint, - struct2static_name_mappings, - is_master_weights=False, - typename_set=None, -): - hcg = fleet.get_hybrid_communicate_group() - tp_group = hcg.get_model_parallel_group() - sharding_group = hcg.get_sharding_parallel_group() - sharding_rank = sharding_group.rank - if sharding_group.nranks > 1: - param2rank = optimizer._param2rank - tp_rank = tp_group.rank - - # Create receive table, contains {"param_key0": [global_rank, tp_rank], "param_key1": [global_rank, tp_rank]} - dispatch_list = [] - recv_table = {} - if args.data_parallel_rank == 0: - state_dict = get_expected_state_dict(model) - for (k, v) in state_dict.items(): - if sharding_group.nranks > 1: - static_name = struct2static_name_mappings[k] - param_rank = param2rank.get(static_name, None) - if param_rank != sharding_rank: - continue - if is_master_weights: - if hasattr(v, "is_distributed") and v.is_distributed: - recv_table[k] = [(dist.get_rank(), tp_rank)] - else: - recv_table[k] = [(dist.get_rank(), -1)] - else: - for typename in typename_set: - type_key = k + "/" + typename - if typename in optimizer_non_scaler_name: - if hasattr(v, "is_distributed") and v.is_distributed: - recv_table[type_key] = [(dist.get_rank(), tp_rank)] - else: - recv_table[type_key] = [(dist.get_rank(), -1)] - else: - recv_table[type_key] = [(dist.get_rank(), -1)] - - dist.all_gather_object(dispatch_list, recv_table) - recv_table = {} - for dl in dispatch_list: - for k, v in dl.items(): - if k not in recv_table: - recv_table[k] = v - else: - recv_table[k] += v - - # Create send table, to decide which worker to send the key. Contains {"param_key0:" 0, "param_key1": 1, ...} - send_table = create_send_table(file_keyname_mappings, file_machine_mappings) - return send_table, recv_table - - -def load_unified_checkpoint_dynamically(args, model, optimizer, resume_from_checkpoint, safe_serialization=False): - index_filename = select_model_weight_index(args, model, resume_from_checkpoint, safe_serialization, local=False) - index_filename = os.path.join(resume_from_checkpoint, index_filename) - - with open(index_filename, "r") as f: - index = json.loads(f.read()) - - # `file_keyname_mappings` indicates which keys each file contains. For example, {"model-00001-of-00002.safetensors": ["llama.embed_tokens.weight", "llama.layers.0.self_attn.q_proj.weight", ...]} - # `file_machine_mappings` indicates the machine where the files appear. For example, {"model-00001-of-00002.safetensors": [machine_0, machine_1], "model-00002-of-00002.safetensors": [machine_0]} - file_keyname_mappings, file_machine_mappings = get_file_mappings(index, resume_from_checkpoint) - - logger.debug("Creating dispatch table for unified checkpoint load ...") - # Get send_table and recv_table. The send table indicates which workers are responsible for sending tensors, and the recv table indicates which workers should receive the tensors. - send_table, recv_table = create_dispatch_table( - args, model, file_keyname_mappings, file_machine_mappings, resume_from_checkpoint - ) - - # Get all the keys that are splited by tensor parallelism. 
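# In recv_table each value is a list of (global_rank, tp_rank) pairs, and tp_rank == -1 marks a
# tensor that is not split by tensor parallelism, so the loop below keeps only the split keys.
# A minimal illustration with a hypothetical table:
#     recv_table = {"linear.weight": [(0, 0), (1, 1)], "norm.bias": [(0, -1), (1, -1)]}
#     tp_keys = {k for k, v in recv_table.items() if v[0][1] != -1}   # -> {"linear.weight"}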
- all_tp_keys = set() - for k, v in recv_table.items(): - if v[0][1] != -1: - all_tp_keys.add(k) - - config_revise = copy.deepcopy(model.config) - config_revise.tensor_parallel_rank = None - if len(all_tp_keys) == 0: - tp_actions = {} - else: - # Get corresponding tensor parallel actions. - if isinstance(model, LoRAModel) or isinstance(model, PrefixModelForCausalLM): - tp_actions = model._get_tensor_parallel_convert_actions( - set(all_tp_keys), is_split=True, ignore_error=True, config=config_revise - ) - else: - tp_actions = model.get_tensor_parallel_convert_actions(config_revise, all_tp_keys, ignore_error=True) - - logger.debug("Distributed send recv for state dict load ...") - # Distribute the checkpoint tensor dynamically, using the `send_table` and `recv_table` we create before. - state_dict = distributed_send_recv( - config_revise, - get_expected_state_dict(model), - tp_actions, - send_table, - recv_table, - resume_from_checkpoint, - file_keyname_mappings, - file_machine_mappings, - ) - dist.barrier() - logger.debug("Setting state dict into model ...") - error_msgs = _load_state_dict_into_model(model, state_dict, "") - if len(error_msgs) > 0: - error_msg = "\n\t".join(error_msgs) - raise RuntimeError(f"Error(s) in loading dynamic state_dict for {model.__class__.__name__}:\n\t{error_msg}") - - -def load_unified_optimizer_dynamically(args, model, optimizer, resume_from_checkpoint, safe_serialization=False): - optim_state_dict = nested_copy(optimizer.state_dict()) - if "master_weights" in optim_state_dict.keys(): - optim_state_dict.pop("master_weights") - - if safe_serialization: - index_filename, index_filename_mw = SAFE_OPTIMIZER_INDEX_NAME, SAFE_MASTER_WEIGHTS_INDEX_NAME - else: - index_filename, index_filename_mw = PADDLE_OPTIMIZER_INDEX_NAME, PADDLE_MASTER_WEIGHTS_INDEX_NAME - - with open(os.path.join(resume_from_checkpoint, index_filename), "r") as f: - index = json.loads(f.read()) - - # `file_keyname_mappings` indicates which keys each file contains. For example, {"optimizer-00001-of-00002.safetensors": ["llama.embed_tokens.weight/moment1_0", "llama.layers.1.mlp.gate_proj.weight/moment1_0", ...]} - # `file_machine_mappings` indicates the machine where the files appear. For example, {"optimizer-00001-of-00002.safetensors": [machine_0, machine_1], "optimizer-00002-of-00002.safetensors": [machine_0]} - file_keyname_mappings, file_machine_mappings = get_file_mappings(index, resume_from_checkpoint) - - has_master_weights = index["master_weights"] - # update has_master_weights and index_filename_master_weights - # 1. if the master weights exists, only has_master_weights is set True and load master weights when needed - # 2. if master weights does not exist, convert model weights to master weights when needed - has_master_weights, index_filename_mw = update_master_weight_status( - args, optimizer, has_master_weights, safe_serialization - ) - - if has_master_weights: - with open(os.path.join(resume_from_checkpoint, index_filename_mw), "r") as f: - index_mw = json.loads(f.read()) - file_keyname_mappings_mw, file_machine_mappings_mw = get_file_mappings(index_mw, resume_from_checkpoint) - - # Get optimizer param type name, like moment1_0, moment2_0, beta1_pow_acc_0. 
- typename_set = set() - for key in index["weight_map"].keys(): - _, typename = key.split("/") - typename_set.add(typename) - - model_state_dict = get_expected_state_dict(model) - struct2static_name_mappings = {k: v.name for k, v in model_state_dict.items()} - static2struct_name_mappings = {v.name: k for k, v in model_state_dict.items()} - # Get send_table and recv_table. The send table indicates which workers are responsible for sending tensors, and the recv table indicates which workers should receive the tensors. - send_table, recv_table = create_optimizer_dispatch_table( - args, - model, - optimizer, - file_keyname_mappings, - file_machine_mappings, - resume_from_checkpoint, - struct2static_name_mappings, - is_master_weights=False, - typename_set=typename_set, - ) - if has_master_weights: - send_table_mw, recv_table_mw = create_optimizer_dispatch_table( - args, - model, - optimizer, - file_keyname_mappings_mw, - file_machine_mappings_mw, - resume_from_checkpoint, - struct2static_name_mappings, - is_master_weights=True, - ) - - # Initialize optimizer state dict. - hcg = fleet.get_hybrid_communicate_group() - sharding_group = hcg.get_sharding_parallel_group() - if sharding_group.nranks > 1: - param2rank = optimizer._param2rank - optim_state_dict_mw = {} - - def check_optimizer_param(parameter): - if sharding_group.nranks > 1: - param_rank = param2rank.get(parameter.name, None) - if param_rank != sharding_group.rank: - return False - if parameter.stop_gradient: - return False - return True - - optimizer_keys_with_shape = [] - if isinstance(optimizer._parameter_list[0], dict): - for param_group in optimizer._parameter_list: - # If parameter groups are set, there must be `params` key. This is guaranteed by the optimizer's initialization code. - for parameter in param_group["params"]: - if check_optimizer_param(parameter): - optimizer_keys_with_shape.append((parameter.name, parameter.shape)) - else: - for parameter in optimizer._parameter_list: - if check_optimizer_param(parameter): - optimizer_keys_with_shape.append((parameter.name, parameter.shape)) - - # see how to change - for static_name, shape in optimizer_keys_with_shape: - k = static2struct_name_mappings[static_name] - for typename in typename_set: - new_k = k + "/" + typename - if typename in optimizer_scalar_name: - optim_state_dict[new_k] = paddle.empty([1], dtype="float32") - else: - optim_state_dict[new_k] = paddle.empty(shape, dtype="float32") - if has_master_weights: - optim_state_dict_mw[k] = paddle.empty(shape, dtype="float32") - - # Get all the keys that are splited by tensor parallelism. - all_tp_keys = set() - for k, v in recv_table.items(): - structure_name, typename = k.split("/") - if typename in optimizer_non_scaler_name: - if v[0][1] != -1: - all_tp_keys.add(structure_name) - - # Get corresponding tensor parallel actions. 
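# Only the per-element optimizer states (moment1_0, moment2_0, ...) follow the tensor-parallel
# split of their base parameter; the scalar accumulators such as beta1_pow_acc_0 are shape [1]
# and need no action. mapping_optimizer_tp_actions below therefore just re-keys the model-level
# actions, roughly (hypothetical key):
#     optimizer_tp_actions["linear_0.weight/moment1_0"] = tp_actions["linear_0.weight"]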
- config_revise = copy.deepcopy(model.config) - config_revise.tensor_parallel_rank = None - if len(all_tp_keys) == 0: - tp_actions = {} - else: - if isinstance(model, LoRAModel) or isinstance(model, PrefixModelForCausalLM): - tp_actions = model._get_tensor_parallel_convert_actions( - set(all_tp_keys), is_split=True, ignore_error=True, config=config_revise - ) - else: - tp_actions = model.get_tensor_parallel_convert_actions(config_revise, all_tp_keys, ignore_error=True) - optimizer_keys = list(index["weight_map"].keys()) - optimizer_tp_actions = mapping_optimizer_tp_actions(tp_actions, optimizer_keys) - if has_master_weights: - optimizer_tp_actions.update(tp_actions) - - # Distribute the optimizer checkpoint dynamically, using the `send_table` and `recv_table` we create before. - optim_state_dict = distributed_send_recv( - config_revise, - optim_state_dict, - optimizer_tp_actions, - send_table, - recv_table, - resume_from_checkpoint, - file_keyname_mappings, - file_machine_mappings, - ) - dist.barrier() - if has_master_weights: - optim_state_dict_mw = distributed_send_recv( - config_revise, - optim_state_dict_mw, - optimizer_tp_actions, - send_table_mw, - recv_table_mw, - resume_from_checkpoint, - file_keyname_mappings_mw, - file_machine_mappings_mw, - ) - dist.barrier() - - # Rename optimizer state dict. - for key in list(optim_state_dict.keys()): - if key == "LR_Scheduler": - continue - key_name = key.split("/") - static_name = struct2static_name_mappings[key_name[0]] - if has_master_weights: - if model_state_dict[key_name[0]].dtype != core.VarDesc.VarType.FP32: - key_name = "_".join([static_name, FP32_MASTER, key_name[1]]) - else: - key_name = "_".join([static_name, key_name[1]]) - else: - key_name = "_".join([static_name, key_name[1]]) - optim_state_dict[key_name] = optim_state_dict.pop(key) - optim_state_dict[key_name].name = key_name - - if has_master_weights: - optim_state_dict["master_weights"] = {} - for key in list(optim_state_dict_mw.keys()): - static_name = struct2static_name_mappings[key] - optim_state_dict["master_weights"][static_name] = optim_state_dict_mw.pop(key) - optim_state_dict["master_weights"][static_name].name = "_".join([static_name, FP32_MASTER]) - - if args.data_parallel_rank == 0: - return optim_state_dict - return None - - -def load_single_card_checkpoint(args, model, resume_from_checkpoint: str): - if isinstance(model, LoRAModel) or isinstance(model, PrefixModelForCausalLM): - index_filename = SAFE_PEFT_WEIGHTS_INDEX_NAME - else: - index_filename = SAFE_WEIGHTS_INDEX_NAME - resolved_archive_file, sharded_metadata = get_checkpoint_shard_files( - pretrained_model_name_or_path=resume_from_checkpoint, - index_filename=os.path.join(resume_from_checkpoint, index_filename), - ) - - loaded_keys = sharded_metadata["all_checkpoint_keys"] - model_state_dict = get_expected_state_dict(model) - expected_keys = set(list(model_state_dict.keys())) - missing_keys = expected_keys - set(loaded_keys) - - if len(missing_keys) > 0: - raise ValueError(f"Missing keys: {missing_keys}") - - state_dict = load_state_dict(resolved_archive_file[0], None, expected_keys) - error_msgs = _load_state_dict_into_model(model, state_dict, "") - del state_dict - gc.collect() - - if error_msgs: - raise RuntimeError(f"Error(s) in loading state dict for {model.__class__.__name__}:\n\t{error_msgs}") - - -def load_single_card_optimizer(args, model, optimizer, resume_from_checkpoint: str): - returned_optim_state_dict = nested_copy(optimizer.state_dict()) - - resolved_archive_file, sharded_metadata = 
get_optimizer_shard_files( - optimizer_path=resume_from_checkpoint, - index_filename=os.path.join(resume_from_checkpoint, SAFE_OPTIMIZER_INDEX_NAME), - ) - has_master_weights = True if sharded_metadata["master_weights"] else False - - model_state_dict = get_expected_state_dict(model) - struct2static_name_mappings = {k: v.name for k, v in model_state_dict.items()} - expected_keys = sharded_metadata["all_optimizer_keys"] - - if has_master_weights: - returned_optim_state_dict["master_weights"] = {} - resolved_archive_file_mw, sharded_metadata_mw = get_optimizer_shard_files( - optimizer_path=resume_from_checkpoint, - index_filename=os.path.join(resume_from_checkpoint, SAFE_MASTER_WEIGHTS_INDEX_NAME), - ) - expected_keys_mw = sharded_metadata_mw["all_optimizer_keys"] - - state_dict_optim = load_state_dict(resolved_archive_file[0], None, expected_keys) - if has_master_weights: - state_dict_optim_mw = load_state_dict(resolved_archive_file_mw[0], None, expected_keys_mw) - - for key in list(state_dict_optim.keys()): - key_name = key.split("/") - static_name = struct2static_name_mappings[key_name[0]] - if has_master_weights: - if model_state_dict[key_name[0]].dtype != core.VarDesc.VarType.FP32: - key_name = "_".join([static_name, FP32_MASTER, key_name[1]]) - else: - key_name = "_".join([static_name, key_name[1]]) - returned_optim_state_dict[key_name] = state_dict_optim.pop(key) - returned_optim_state_dict[key_name].name = key_name - if has_master_weights: - for key in list(state_dict_optim_mw.keys()): - static_name = struct2static_name_mappings[key] - returned_optim_state_dict["master_weights"][static_name] = state_dict_optim_mw.pop(key) - returned_optim_state_dict["master_weights"][static_name].name = "_".join([static_name, FP32_MASTER]) - - returned_optim_state_dict = nested_copy_place( - returned_optim_state_dict, - place=paddle.framework._current_expected_place(), - blocking=True, - ) - return returned_optim_state_dict - - -def get_file_mappings(index, resume_from_checkpoint): - file_keyname_mappings = {} - for k, v in index["weight_map"].items(): - if v not in file_keyname_mappings: - file_keyname_mappings[v] = [] - file_keyname_mappings[v].append(k) - for k in file_keyname_mappings.keys(): - file_keyname_mappings[k] = sorted(file_keyname_mappings[k]) - - local_device_count = int(os.getenv("PADDLE_LOCAL_SIZE")) - local_rank = int(os.getenv("PADDLE_RANK_IN_NODE", 0)) - global_rank = dist.get_rank() - file_machine_mappings = {} - for filename in file_keyname_mappings.keys(): - if local_rank == 0 and os.path.exists(os.path.join(resume_from_checkpoint, filename)): - file_machine_mappings[filename] = [global_rank // local_device_count] - file_machine_list = [] - dist.all_gather_object(file_machine_list, file_machine_mappings) - file_machine_mappings = {} - for mappings in file_machine_list: - for k, v in mappings.items(): - if k not in file_machine_mappings: - file_machine_mappings[k] = v - else: - file_machine_mappings[k] += v - return file_keyname_mappings, file_machine_mappings - - -def create_send_table(file_keyname_mappings, file_machine_mappings): - send_table = {} - global_rank = dist.get_rank() - local_rank = int(os.getenv("PADDLE_RANK_IN_NODE", 0)) - local_device_count = int(os.getenv("PADDLE_LOCAL_SIZE")) - for filename, keys in file_keyname_mappings.items(): - machine = file_machine_mappings[filename][0] - is_src = (global_rank // local_device_count) == machine - for i, key in enumerate(keys): - if is_src and local_rank == i % local_device_count: - send_table[key] = global_rank - 
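# Keys within a file are handed out round-robin to the local ranks of the machine that holds
# the file (key i goes to local rank i % local_device_count); the all_gather below then unions
# the per-rank tables so that every rank knows which sender owns every key.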
dispatch_list = [] - dist.all_gather_object(dispatch_list, send_table) - send_table = {} - for dl in dispatch_list: - send_table.update(dl) - return send_table - - -def distributed_send_recv( - config, - state_dict, - tp_actions, - send_table, - recv_table, - resume_from_checkpoint, - file_keyname_mappings, - file_machine_mappings, -): - - local_device_count = int(os.getenv("PADDLE_LOCAL_SIZE")) - global_rank = dist.get_rank() - for filename in file_keyname_mappings.keys(): - machine = file_machine_mappings[filename][0] - is_src = global_rank // local_device_count == machine - if is_src: - f = safe_open(os.path.join(resume_from_checkpoint, filename), framework="np") - - for key in file_keyname_mappings[filename]: - recv_info = recv_table[key] - recv_ranklist = [a for (a, b) in recv_info] - if is_src and global_rank == send_table[key]: - py_safe_slice_ = f.get_slice(key) - # send - if key in tp_actions: - weight = tp_actions[key](py_safe_slice_) - # copy weight to GPU - for j in range(len(weight)): - with device_guard(): - weight[j] = paddle.Tensor(weight[j], zero_copy=True) - weight[j] = weight[j]._copy_to(paddle.framework._current_expected_place(), False) - - for recv_rank, split_index in recv_info: - if recv_rank == global_rank: - state_dict[key] = weight[split_index] - else: - dist.stream.send(weight[split_index], dst=recv_rank) - else: - # no need to tp split - weight = py_safe_slice_[:] - with device_guard(): - weight = paddle.Tensor(weight, zero_copy=True) - weight = weight._copy_to(paddle.framework._current_expected_place(), False) - for recv_rank, _ in recv_info: - if recv_rank == global_rank: - state_dict[key] = weight - else: - dist.stream.send(weight, dst=recv_rank) - - if global_rank != send_table[key] and global_rank in recv_ranklist: - dist.stream.recv(state_dict[key], src=send_table[key]) - - if is_src: - f.__exit__(None, None, None) - - return state_dict - - -def get_sharded_file_name(args, file_name, is_optimizer=False): - if not is_optimizer: - sd_degree = args.sharding_parallel_degree if args.sharding_parallel_degree > 1 else 1 - size = sd_degree if args.use_expert_parallel else args.dataset_world_size - shard_file = file_name.replace( - ".pdparams", - f"-{args.logical_process_index + 1:05d}-of-{args.world_size//size:05d}.pdparams", - ) - shard_file = shard_file.replace( - ".safetensors", - f"-{args.logical_process_index + 1:05d}-of-{args.world_size//size:05d}.safetensors", - ) - else: - hcg = fleet.get_hybrid_communicate_group() - dp_group = hcg.get_data_parallel_group() - size = dp_group.nranks if not args.use_expert_parallel else 1 - shard_file = file_name.replace( - ".pdparams", f"-{args.logical_process_index + 1:05d}-of-{args.world_size//size:05d}.pdparams" - ) - shard_file = shard_file.replace( - ".safetensors", - f"-{args.logical_process_index + 1:05d}-of-{args.world_size//size:05d}.safetensors", - ) - shard_file = shard_file.replace( - ".pdopt", f"-{args.logical_process_index + 1:05d}-of-{args.world_size//size:05d}.pdopt" - ) - return shard_file - - -def get_sharded_index( - index_file_list, - total_size_list, -): - # save index json file - local_rank = int(os.getenv("PADDLE_RANK_IN_NODE", 0)) - if local_rank == 0: - sharded_index_json = {} - - sharded_index_json["metadata"] = {"total_size": sum(total_size_list)} - - weight_map = {} - for i, index_file in enumerate(index_file_list): - weight_map.update(index_file_list[i]) - - sharded_index_json["weight_map"] = weight_map - return sharded_index_json - - return None - - -def 
reduce_master_weights_status(has_master_weights=False): - data = paddle.to_tensor([has_master_weights], dtype="int32") - - hcg = fleet.get_hybrid_communicate_group() - tp_group = hcg.get_model_parallel_group() - pp_group = hcg.get_pipe_parallel_group() - sharding_group = hcg.get_sharding_parallel_group() - - if tp_group.nranks > 1: - dist.all_reduce(data, op=dist.ReduceOp.SUM, group=tp_group) - if pp_group.nranks > 1: - dist.all_reduce(data, op=dist.ReduceOp.SUM, group=pp_group) - if sharding_group.nranks > 1: - dist.all_reduce(data, op=dist.ReduceOp.SUM, group=sharding_group) - - return data.item() > 0 - - -def gather_sharded_object(index_file, total_size, is_optimizer=False, use_expert_parallel=False): - - index_file_list, total_size_list = [], [] - - hcg = fleet.get_hybrid_communicate_group() - tp_group = hcg.get_model_parallel_group() - pp_group = hcg.get_pipe_parallel_group() - - logger.info( - f"Unified checkpoint: generating sharded_index json files for {'optimizer or master weight' if is_optimizer else 'model weight'}." - ) - - if tp_group.nranks > 1: - dist.all_gather_object(index_file_list, index_file, tp_group) - dist.all_gather_object(total_size_list, total_size, tp_group) - if pp_group.nranks > 1: - pp_index_file_list = [] - pp_total_size_list = [] - dist.all_gather_object( - pp_index_file_list, index_file_list if len(index_file_list) > 0 else index_file, pp_group - ) - dist.all_gather_object( - pp_total_size_list, total_size_list if len(total_size_list) > 0 else total_size, pp_group - ) - index_file_list = pp_index_file_list - total_size_list = pp_total_size_list - - index_file_list = flatten_list(index_file_list) - total_size_list = flatten_list(total_size_list) - - # for pure sharding - if len(index_file_list) == 0 and len(total_size_list) == 0: - index_file_list = [index_file] - total_size_list = [total_size] - - if use_expert_parallel: - data_group = hcg.get_data_parallel_group() - if data_group.nranks > 1: - data_index_file_list = [] - data_total_size_list = [] - dist.all_gather_object(data_index_file_list, index_file_list, data_group) - dist.all_gather_object(data_total_size_list, total_size_list, data_group) - index_file_list = flatten_list(data_index_file_list) - total_size_list = flatten_list(data_total_size_list) - - if is_optimizer: - sharding_group = hcg.get_sharding_parallel_group() - if sharding_group.nranks > 1: - sharding_index_file_list = [] - sharding_total_size_list = [] - dist.all_gather_object(sharding_index_file_list, index_file_list, sharding_group) - dist.all_gather_object(sharding_total_size_list, total_size_list, sharding_group) - index_file_list = flatten_list(sharding_index_file_list) - total_size_list = flatten_list(sharding_total_size_list) - - return index_file_list, total_size_list - - -def rename_shard_file(args, shard_file, file_name): - """rename shard file when using expert_parallel.""" - assert args.use_expert_parallel, "only expert_parallel need to use this function" - - shard_file_list = [] - - hcg = fleet.get_hybrid_communicate_group() - tp_group = hcg.get_model_parallel_group() - pp_group = hcg.get_pipe_parallel_group() - data_group = hcg.get_data_parallel_group() - - if tp_group.nranks > 1: - dist.all_gather_object(shard_file_list, shard_file, tp_group) - if pp_group.nranks > 1: - pp_shard_file_list = [] - dist.all_gather_object( - pp_shard_file_list, shard_file_list if len(shard_file_list) > 0 else shard_file, pp_group - ) - shard_file_list = flatten_list(pp_shard_file_list) - if data_group.nranks > 1: - data_shard_file_list = [] - 
dist.all_gather_object( - data_shard_file_list, shard_file_list if len(shard_file_list) > 0 else shard_file, data_group - ) - shard_file_list = flatten_list(data_shard_file_list) - - new_index = shard_file_list.index(shard_file) - sd_degree = args.sharding_parallel_degree if args.sharding_parallel_degree > 1 else 1 - shard_file = file_name.replace( - ".pdparams", - f"-{new_index + 1:05d}-of-{args.world_size//sd_degree:05d}.pdparams", - ) - shard_file = shard_file.replace( - ".safetensors", - f"-{new_index + 1:05d}-of-{args.world_size//sd_degree:05d}.safetensors", - ) - return shard_file - - -def generate_base_static_name(vname): - # return base static name and specific type name, like [embedding_0.w_0, moment1_0] - if FP32_MASTER in vname: - vname = vname.split("_" + FP32_MASTER + "_") - return vname[0], vname[1] - else: - # Directly deal with type names, for example: moe_gate_1_moment1_0. - type_names = optimizer_scalar_name + optimizer_non_scaler_name - for name in type_names: - if name in vname: - a = vname.split(name)[0][:-1] - b = name - return a, b - - -def filter_params(model_to_save, state_dict, is_optimizer=False): - hcg = fleet.get_hybrid_communicate_group() - tp_group = hcg.get_model_parallel_group() - - tp_size = tp_group.nranks - tp_rank = tp_group.rank - - # for pure sharding or pure pp - if tp_size <= 1: - return [list(state_dict.keys())] - - filter_tensor_list = [[] for i in range(tp_size)] - - if tp_rank == 0: - tensor_bytes_dict = {} - model_state_dict = get_expected_state_dict(model_to_save) - for (k, v) in state_dict.items(): - model_v = model_state_dict[k.split("/")[0]] if is_optimizer else v - if hasattr(model_v, "is_distributed") and model_v.is_distributed: - tensor_bytes_dict[k] = v.numel().item() * tp_size * dtype_byte_size(v.dtype) - else: - tensor_bytes_dict[k] = v.numel().item() * dtype_byte_size(v.dtype) - - filter_tensor_list = [] - current_block = [] - current_block_size = 0 - total_size = 0 - - max_shard_size = (sum(tensor_bytes_dict.values()) + tp_size - 1) // tp_size - - for index, (key, weight_size) in enumerate(tensor_bytes_dict.items()): - # If this weight is going to tip up over the maximal size, we split. 
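# filter_params packs the keys greedily into tp_size roughly equal buckets by byte size:
# max_shard_size is the per-rank target (ceil of total bytes / tp_size), and a new bucket is
# opened once the running total would cross the next multiple of that target, or when the tail
# of the key list has to be spread over the buckets that are still empty. The condition below
# implements that split.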
- # if current_block_size + weight_size > max_shard_size: - if total_size + weight_size > max_shard_size * (len(filter_tensor_list) + 1) or ( - len(tensor_bytes_dict) - index < (tp_size - len(filter_tensor_list)) - ): - # fix if the first param is large than max_shard_size - if len(current_block) > 0: - filter_tensor_list.append(current_block) - current_block = [] - current_block_size = 0 - - current_block.append(key) - current_block_size += weight_size - total_size += weight_size - - filter_tensor_list.append(current_block) - if len(filter_tensor_list) < tp_size: - filter_tensor_list.extend([[] for i in range(tp_size - len(filter_tensor_list))]) - - dist.broadcast_object_list( - filter_tensor_list, - src=hcg.get_model_parallel_group_src_rank(), - group=tp_group, - ) - - return filter_tensor_list - - -def merge_large_tensor_parallel(tensor, tp_group, tp_action, dst_rank, is_dst): - num_rows = tensor.shape[0] - num_splits = 4 - parts = np.array_split(np.arange(num_rows), num_splits) - splits = [len(part) for part in parts] - split_parts = np.insert(np.cumsum(splits), 0, 0) - split_tensors = [] - for i in range(num_splits): - if get_env_device() == "xpu": - ret = distributed_allgather(tensor[split_parts[i] : split_parts[i + 1], :], group=tp_group, offload=False) - else: - ret = distributed_gather( - tensor[split_parts[i] : split_parts[i + 1], :], dst=dst_rank, group=tp_group, offload=False - ) - # Copy to CPUPlace temporarily, may lower speed. - if ret is not None: - ret = [t.cpu() for t in ret] - split_tensors.append(ret) - concat_tensors = [] - if is_dst: - for i in range(tp_group.nranks): - tmp = [] - for j in range(num_splits): - tmp.append(split_tensors[j][i]) - concat_tensors.append(paddle.concat(tmp)) - tensor = tp_action(concat_tensors) - else: - tensor = None - return tensor - - -def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys): - hcg = fleet.get_hybrid_communicate_group() - tp_group = hcg.get_model_parallel_group() - dp_group = hcg.get_data_parallel_group() - tp_rank = tp_group.rank - dp_rank = dp_group.rank if dp_group.nranks > 1 else 0 - - # filter actions for pipeline mode - if hcg.get_pipe_parallel_group().nranks > 1: - filter_keys = set([y for x in all_filter_keys for y in x]) - for key in list(tp_actions.keys()): - if key not in filter_keys: - tp_actions.pop(key) - - state_dict_to_save = {} - max_key_len = max([len(_) for _ in all_filter_keys]) - for i in range(max_key_len): - for j, filter_keys in enumerate(all_filter_keys): - is_dst = tp_rank == j - if i > len(filter_keys) - 1: - continue - key = filter_keys[i] - tensor = state_dict[key] - # When using expert parallel, there's no need to save tensors with `no_sync=False` when dp_rank > 0. 
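# (Tensors carrying no_sync=True are the expert-parallel local ones that differ across
#  data-parallel ranks, so every dp rank must keep its own copy; everything else is replicated
#  and is written out by dp_rank 0 only.)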
- if dp_rank > 0 and not getattr(tensor, "no_sync", False): - continue - if key in tp_actions: - # Get tensor size - tensor_bytes = tensor.numel().item() * dtype_byte_size(tensor.dtype) * tp_group.nranks - if tensor_bytes >= 5 * 1024 * 1024 * 1024: # temporarily set 5GB as threshold - tensor = merge_large_tensor_parallel(tensor, tp_group, tp_actions[key], j, is_dst) - else: - if get_env_device() == "xpu": - ret = distributed_allgather(tensor, group=tp_group, offload=False) - else: - ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False) - action = tp_actions.pop(key) - tensor = action(ret) if is_dst else None - else: - if is_dst: - tensor = tensor._copy_to(DEST_PLACE, False) if tensor.place.is_cpu_place() else tensor - else: - tensor = None - - if is_dst: - state_dict_to_save[key] = tensor - - if len(tp_actions) > 0: - for x in tp_actions.keys(): - logger.debug(f"key <{x}> need to merge tensor parallel but we can't find in model state.") - - return state_dict_to_save - - -def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys, model_state_dict=None): - # Core function for UC - hcg = fleet.get_hybrid_communicate_group() - tp_group = hcg.get_model_parallel_group() - dp_group = hcg.get_data_parallel_group() - tp_rank = tp_group.rank - dp_rank = dp_group.rank if dp_group.nranks > 1 else 0 - - no_sync_kname = [] - if model_state_dict is not None: - for k, v in model_state_dict.items(): - if getattr(v, "no_sync", False): - no_sync_kname.append(k) - - state_dict_to_save = {} - max_key_len = max([len(_) for _ in all_filter_keys]) - for i in range(max_key_len): - for j, filter_keys in enumerate(all_filter_keys): - is_dst = tp_rank == j - if i > len(filter_keys) - 1: - continue - # get base model key - model_key = filter_keys[i].split("/")[0] - tensor = state_dict[filter_keys[i]] - # When using expert parallel, there's no need to save tensors with `no_sync=False` when dp_rank > 0. - if dp_rank > 0 and model_key not in no_sync_kname: - continue - if model_key in tp_actions: - # for example: beta1, beta2 - if tensor.numel().item() == 1: - if is_dst: - tensor = tensor._copy_to(DEST_PLACE, False) if not tensor.place.is_cpu_place() else tensor - else: - tensor = None - else: - # Get tensor size - tensor_bytes = tensor.numel().item() * dtype_byte_size(tensor.dtype) * tp_group.nranks - if tensor_bytes >= 5 * 1024 * 1024 * 1024: # temporarily set 5GB as threshold - tensor = merge_large_tensor_parallel(tensor, tp_group, tp_actions[model_key], j, is_dst) - else: - if get_env_device() == "xpu": - ret = distributed_allgather(tensor, group=tp_group, offload=False) - else: - ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False) - action = tp_actions[model_key] - tensor = action(ret) if is_dst else None - else: - if is_dst: - tensor = tensor._copy_to(DEST_PLACE, False) if not tensor.place.is_cpu_place() else tensor - else: - tensor = None - - if is_dst: - state_dict_to_save[filter_keys[i]] = tensor - - return state_dict_to_save - - -def get_optimizer_shard_files(optimizer_path, index_filename): - """ - For a given model: - - download and cache all the shards of a sharded checkpoint if `pretrained_model_name_or_path` is a model ID on the - Hub - - returns the list of paths to all the shards, as well as some metadata. - For the description of each arg, see [`PretrainedModel.from_pretrained`]. `index_filename` is the full path to the - index (downloaded and cached if `pretrained_model_name_or_path` is a model ID on the Hub). 
- """ - - import json - - if not os.path.isfile(index_filename): - raise ValueError(f"Can't find a optimizer index ({index_filename}) in {optimizer_path}.") - - with open(index_filename, "r") as f: - index = json.loads(f.read()) - - shard_filenames = sorted(set(index["weight_map"].values())) - sharded_metadata = index["metadata"] - sharded_metadata["all_optimizer_keys"] = list(index["weight_map"].keys()) - sharded_metadata["weight_map"] = index["weight_map"].copy() - sharded_metadata["master_weights"] = index.get("master_weights", False) - - file_map = {file: set() for file in shard_filenames} - for weight, file in index["weight_map"].items(): - file_map[file].add(weight) - - sharded_metadata["file_map"] = file_map - - # First, let's deal with local folder. - # TODO: if optimizer_path is a folder, we should check if the optimizer is already cached or not. - if os.path.isdir(optimizer_path): - shard_filenames = [os.path.join(optimizer_path, f) for f in shard_filenames] - return shard_filenames, sharded_metadata - - -def get_expected_keys(args, sharded_metadata, model, optimizer, is_master_weights=False): - hcg = fleet.get_hybrid_communicate_group() - sharding_group = hcg.get_sharding_parallel_group() - sharding_rank = sharding_group.rank - in_sharding_parallel_model = sharding_group.nranks > 1 - if in_sharding_parallel_model: - params2rank = optimizer._param2rank - - model_state_dict = get_expected_state_dict(model) - struct2static_name_mappings = {k: v.name for k, v in model_state_dict.items()} - - expected_keys = [] - for key in list(sharded_metadata["all_optimizer_keys"]): - key_name = key.split("/")[0] - if ( - is_master_weights - and key_name in model_state_dict - and model_state_dict[key_name].dtype == core.VarDesc.VarType.FP32 - ): - continue - - if args.use_expert_parallel and args.data_parallel_rank > 0: - if key_name in model_state_dict and not getattr(model_state_dict[key_name], "no_sync", False): - continue - - static_name = struct2static_name_mappings.get(key_name, None) - - if in_sharding_parallel_model: - params_rank = params2rank.get(static_name, None) - if params_rank == sharding_rank: - expected_keys.append(key) - else: - if static_name is not None: - expected_keys.append(key) - expected_keys = set(expected_keys) - - loaded_keys = sharded_metadata["all_optimizer_keys"] - missing_keys = expected_keys - set(loaded_keys) - if len(missing_keys) > 0: - raise ValueError(f"optimizer missing weights keys: {missing_keys}") - - return expected_keys - - -def mapping_optimizer_tp_actions(tp_actions, optimizer_loaded_keys): - """# convert param.name to - param.key/moment1_0 - or param.key/beta1_XXX - or param.key/beta2_XXX - Args: - tp_actions (dict): dictionay of tensor parallel actions {key: action} - optimizer_loaded_keys (list or set): [param.key1/moment1_0, param.key2/beta1_XXX, param.key3/beta2_XXX] - Returns: - dict: new dictionay of tensor parallel actions {key: action} - """ - new_actions = {} - for key in optimizer_loaded_keys: - key_base, typename = key.split("/") - if typename in optimizer_non_scaler_name and key_base in tp_actions: - new_actions[key] = tp_actions[key_base] - return new_actions - - -def flatten_list(nested_list): - flattened_list = [] - for item in nested_list: - if isinstance(item, list): - flattened_list.extend(flatten_list(item)) - else: - flattened_list.append(item) - return flattened_list - - -def select_model_weight_index(args, model, resume_from_checkpoint, safe_serialization, local=True): - """ - try select model weight index from model weight or 
master weight index. - """ - - # find model weight index file - if isinstance(model, LoRAModel) or isinstance(model, PrefixModelForCausalLM): - index_filename = SAFE_PEFT_WEIGHTS_INDEX_NAME if safe_serialization else PADDLE_PEFT_WEIGHTS_INDEX_NAME - else: - index_filename = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else PADDLE_WEIGHTS_INDEX_NAME - - index_filename_path = os.path.join(resume_from_checkpoint, index_filename) - identify_func = os.path.isfile if local else distributed_isfile - - if identify_func(index_filename_path): - return index_filename - else: - index_filename = PADDLE_MASTER_WEIGHTS_INDEX_NAME if not safe_serialization else SAFE_MASTER_WEIGHTS_INDEX_NAME - index_filename_path = os.path.join(resume_from_checkpoint, index_filename) - - if identify_func(index_filename_path): - return index_filename - else: - raise ValueError("Can't find a valid unified model or master weight checkpoint to load.") - - -def update_master_weight_status(args, optimizer, has_master_weight, safe_serialization): - if is_need_master_weight(optimizer, is_fp16_or_bp16=(args.fp16 or args.bf16)): - if not has_master_weight: - if UnifiedCheckpointOption.MASTER_WEIGHT_COMPATIBLE.value in args.unified_checkpoint_config: - index_filename_master_weights = ( - PADDLE_WEIGHTS_INDEX_NAME if not safe_serialization else SAFE_WEIGHTS_INDEX_NAME - ) - has_master_weight = True - logger.warning( - "The unified checkpoint does not contain master weight, " - "the model weight will be loaded as master weight." - ) - else: - raise ValueError( - "Can't find a valid unified master weight checkpoint," - f"add '{UnifiedCheckpointOption.MASTER_WEIGHT_COMPATIBLE.value}' into 'unified_checkpoint_config' to " - "load model checkpoint as master weight" - ) - else: - has_master_weight = True - index_filename_master_weights = ( - PADDLE_MASTER_WEIGHTS_INDEX_NAME if not safe_serialization else SAFE_MASTER_WEIGHTS_INDEX_NAME - ) - if UnifiedCheckpointOption.SKIP_SAVE_MODEL_WEIGHT.value in args.unified_checkpoint_config: - index_filename_master_weights = ( - PADDLE_WEIGHTS_INDEX_NAME if not safe_serialization else SAFE_WEIGHTS_INDEX_NAME - ) - else: - has_master_weight = False - index_filename_master_weights = None - - return has_master_weight, index_filename_master_weights - - -def unwrap_optimizer(optimizer): - while hasattr(optimizer, "_inner_opt") or hasattr(optimizer, "_optim"): - if hasattr(optimizer, "_inner_opt"): - optimizer = optimizer._inner_opt - if hasattr(optimizer, "_optim"): - optimizer = optimizer._optim - - return optimizer - - -def is_need_master_weight(optimizer, is_fp16_or_bp16): - optimizer = unwrap_optimizer(optimizer) - if hasattr(optimizer, "_multi_precision"): - return optimizer._multi_precision and is_fp16_or_bp16 - else: - return False diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py index a1fb78de63ba..ba280a1832d5 100644 --- a/paddlenlp/trainer/trainer.py +++ b/paddlenlp/trainer/trainer.py @@ -113,7 +113,6 @@ from .argparser import strtobool from .integrations import get_reporting_integration_callbacks from .plugins.timer import RuntimeTimer, get_timers, set_timers -from .plugins.unified_checkpoint import UnifiedCheckpointHandler from .trainer_callback import ( CallbackHandler, DefaultFlowCallback, @@ -144,6 +143,7 @@ speed_metrics, ) from .training_args import TrainingArguments +from .unified_checkpoint import UnifiedCheckpointHandler from .utils import reshard as reshard_util from .utils.async_save import AsyncSaver from .utils.helper import ( # nested_truncate, @@ -598,7 
+598,6 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None): if use_unified_checkpoint: self.unified_checkpoint_handler.load_unified_checkpoint( self.model, - self.optimizer, resume_from_checkpoint, ) logger.info(f"Loading model from {resume_from_checkpoint} using unified checkpoint.") @@ -1241,7 +1240,6 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg): if self.args.unified_checkpoint: self.unified_checkpoint_handler.load_unified_checkpoint( self.model, - self.optimizer, self.state.best_model_checkpoint, ) if self.args.sharding_parallel_degree > 1 or self.args.data_parallel_degree > 1: @@ -1289,7 +1287,6 @@ def _load_best_model_from_peft_checkpoint(self): if self.args.unified_checkpoint: self.unified_checkpoint_handler.load_unified_checkpoint( self.model, - self.optimizer, self.state.best_model_checkpoint, ) if self.args.sharding_parallel_degree > 1 or self.args.data_parallel_degree > 1: @@ -2775,7 +2772,6 @@ def _load_optimizer_and_scheduler(self, checkpoint): opt_state_dict = None else: opt_state_dict = self.unified_checkpoint_handler.load_unified_optimizer( - args=self.args, model=self.model, optimizer=self.optimizer, resume_from_checkpoint=checkpoint, diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py index 567349dfde1c..569d188b2772 100644 --- a/paddlenlp/trainer/training_args.py +++ b/paddlenlp/trainer/training_args.py @@ -1402,6 +1402,12 @@ def is_segment_parallel_supported(): f"but got logging_steps={self.logging_steps}." ) + if "split_param" in sharding_parallel_config: + assert self.sharding == [ShardingOption.SHARD_OP], "Only sharding stage1 support split_param." + assert ( + self.amp_master_grad + ), "If `split_param` in sharding_parallel_config, `amp_master_grad` must be True." + fleet.init(is_collective=True, strategy=strategy) logger.info(strategy) diff --git a/paddlenlp/trainer/unified_checkpoint/__init__.py b/paddlenlp/trainer/unified_checkpoint/__init__.py new file mode 100644 index 000000000000..20a336cb3d8f --- /dev/null +++ b/paddlenlp/trainer/unified_checkpoint/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .unified_checkpoint import UnifiedCheckpointHandler diff --git a/paddlenlp/trainer/unified_checkpoint/async_handler.py b/paddlenlp/trainer/unified_checkpoint/async_handler.py new file mode 100644 index 000000000000..4206821b50e5 --- /dev/null +++ b/paddlenlp/trainer/unified_checkpoint/async_handler.py @@ -0,0 +1,250 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Asynchronous unified checkpoint handler.""" + +import multiprocessing +import os +import time +from multiprocessing import shared_memory + +import paddle +import paddle.distributed as dist + +from paddlenlp.transformers.utils import is_safetensors_available +from paddlenlp.utils.log import logger + +if is_safetensors_available(): + from safetensors.numpy import save_file as safe_save_file + +from .shared_memory_utils import ( + _read_state_dict_from_shm, + _traverse_copy_to_shm, + create_meta_dict, +) + +__all__ = ["AsyncCheckpointHandler"] + + +class AsyncCheckpointHandler: + def __init__(self, args): + # Mainly for asynchronous saving. + self.args = args + self.global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1 + + self._shm_model_weight = None + self._shm_master_weight = None + self._shm_optimizer_weight = None + self._meta_dict_model = None + self._meta_dict_master_weight = None + self._meta_dict_optim = None + self._process_model_weight = None + self._process_master_weight = None + self._process_optimizer_weight = None + self._lock = None + self._shared_save_model_flag = None + self._shared_save_master_weight_flag = None + self._shared_save_optimizer_flag = None + + if "async_save" in self.args.unified_checkpoint_config: + self._lock = multiprocessing.Lock() + self._shared_save_model_path = multiprocessing.Array("c", 100000) + self._shared_save_model_signal_path = multiprocessing.Array("c", 100000) + self._shared_save_master_weight_path = multiprocessing.Array("c", 100000) + self._shared_save_master_weight_signal_path = multiprocessing.Array("c", 100000) + self._shared_save_optimizer_path = multiprocessing.Array("c", 100000) + self._shared_save_optimizer_signal_path = multiprocessing.Array("c", 100000) + self._shared_save_model_flag = multiprocessing.Array("i", 1) + self._shared_save_master_weight_flag = multiprocessing.Array("i", 1) + self._shared_save_optimizer_flag = multiprocessing.Array("i", 1) + + def _file_save_async_or_sync( + self, state_dict, path, signal_path=None, is_sync=True, state_dict_type="model_weight" + ): + if is_sync: + for k in list(state_dict.keys()): + if isinstance(state_dict[k], paddle.Tensor): + state_dict[k] = state_dict.pop(k).cpu().numpy() + safe_save_file(state_dict, path, metadata={"format": "np"}) + else: + if state_dict_type == "model_weight": + if self._shm_model_weight is None: + self._meta_dict_model, buffer_size = create_meta_dict(state_dict) + self._shm_model_weight = shared_memory.SharedMemory(create=True, size=buffer_size) + shm_state_dict = self._shm_model_weight + meta_dict = self._meta_dict_model + shared_save_flag = self._shared_save_model_flag + shared_save_path = self._shared_save_model_path + shared_save_signal_path = self._shared_save_model_signal_path + if self._process_model_weight is None: + self._process_model_weight = multiprocessing.Process( + target=self._save_file_async_in_process, + args=( + meta_dict, + self._shm_model_weight.name, + self._shared_save_model_flag, + self._shared_save_model_path, + self._shared_save_model_signal_path, + self._lock, + state_dict_type, + 
self.global_rank, + ), + ) + self._process_model_weight.start() + process = self._process_model_weight + elif state_dict_type == "master_weight": + if self._shm_master_weight is None: + self._meta_dict_master_weight, buffer_size = create_meta_dict(state_dict) + self._shm_master_weight = shared_memory.SharedMemory(create=True, size=buffer_size) + shm_state_dict = self._shm_master_weight + meta_dict = self._meta_dict_master_weight + shared_save_flag = self._shared_save_master_weight_flag + shared_save_path = self._shared_save_master_weight_path + shared_save_signal_path = self._shared_save_master_weight_signal_path + if self._process_master_weight is None: + self._process_master_weight = multiprocessing.Process( + target=self._save_file_async_in_process, + args=( + meta_dict, + self._shm_master_weight.name, + self._shared_save_master_weight_flag, + self._shared_save_master_weight_path, + self._shared_save_master_weight_signal_path, + self._lock, + "model_weight" + if "skip_save_model_weight" in self.args.unified_checkpoint_config + else state_dict_type, + self.global_rank, + ), + ) + self._process_master_weight.start() + process = self._process_master_weight + elif state_dict_type == "optimizer_weight": + if self._shm_optimizer_weight is None: + self._meta_dict_optim, buffer_size = create_meta_dict(state_dict) + self._shm_optimizer_weight = shared_memory.SharedMemory(create=True, size=buffer_size) + shm_state_dict = self._shm_optimizer_weight + meta_dict = self._meta_dict_optim + shared_save_flag = self._shared_save_optimizer_flag + shared_save_path = self._shared_save_optimizer_path + shared_save_signal_path = self._shared_save_optimizer_signal_path + if self._process_optimizer_weight is None: + self._process_optimizer_weight = multiprocessing.Process( + target=self._save_file_async_in_process, + args=( + meta_dict, + self._shm_optimizer_weight.name, + self._shared_save_optimizer_flag, + self._shared_save_optimizer_path, + self._shared_save_optimizer_signal_path, + self._lock, + state_dict_type, + self.global_rank, + ), + ) + self._process_optimizer_weight.start() + process = self._process_optimizer_weight + + while True: # wait until no process is saving. + flag_value = shared_save_flag[0] + if flag_value == 0: + break + if not process.is_alive(): + raise RuntimeError(f"The process that saves {state_dict_type} has been killed unexpectedly.") + time.sleep(0.5) + logger.info(f"Wait for the previous save process to finish saving {state_dict_type}") + # only save model weight or save master weight, we enter this loop. + self._reset_and_update(shared_save_path, path) + self._reset_and_update(shared_save_signal_path, signal_path) + _traverse_copy_to_shm(state_dict, meta_dict, shm_state_dict.buf) + with self._lock: + shared_save_flag[0] = 1 + + def _save_file_async_in_process( + self, + meta_dict, + shm_name, + shared_save_flag, + shared_save_path, + shared_save_signal_path, + lock, + state_dict_type, + global_rank, + ): + shm = shared_memory.SharedMemory(name=shm_name) + while True: + flag_value = shared_save_flag[0] # if process uses `spawn`, cannot read this value. 
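# Flag protocol between the trainer process and this writer process:
#   -1 -> shut the writer down (set by unlink_shared_memory on exit)
#    0 -> idle, nothing to write
#    1 -> a state dict has been copied into shared memory; write it to `path`, drop a
#         ".{state_dict_type}.done.{global_rank}" signal file, then reset the flag to 0.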
+ if flag_value == -1: # stop process + break + if flag_value == 0: # nothing to save + continue + if flag_value == 1: # need to save + path = shared_save_path[:].decode("utf-8").rstrip("\x00") + signal_path = shared_save_signal_path[:].decode("utf-8").rstrip("\x00") + logger.info(f"Start to async save {path}") + state_dict = _read_state_dict_from_shm(meta_dict, shm) # numpy array + safe_save_file(state_dict, path, {"format": "np"}) + del state_dict + saved_signal_path = os.path.join(signal_path, f".{state_dict_type}.done.{global_rank}") + paddle.save(global_rank, saved_signal_path) + with lock: + shared_save_flag[0] = 0 + time.sleep(0.5) + shm.close() + + def _reset_and_update(self, shared_array, new_value): + # clear array + for i in range(len(shared_array)): + shared_array[i] = b"\0" + # update array + encoded_value = new_value.encode("utf-8") + shared_array[: len(encoded_value)] = encoded_value + + def unlink_shared_memory(self): + if not ("async_save" in self.args.unified_checkpoint_config): + return + + if self._shared_save_model_flag is not None: + while self._shared_save_model_flag[0] > 0: # async process is saving + if not self._process_model_weight.is_alive(): + raise RuntimeError("The process that saves model_weight has been killed unexpectedly.") + time.sleep(0.5) + self._shared_save_model_flag[0] = -1 + if self._shared_save_master_weight_flag is not None: + while self._shared_save_master_weight_flag[0] > 0: + if not self._process_master_weight.is_alive(): + raise RuntimeError("The process that saves master_weight has been killed unexpectedly.") + time.sleep(0.5) + self._shared_save_master_weight_flag[0] = -1 + if self._shared_save_optimizer_flag is not None: + while self._shared_save_optimizer_flag[0] > 0: + if not self._process_optimizer_weight.is_alive(): + raise RuntimeError("The process that saves optimizer_weight has been killed unexpectedly.") + time.sleep(0.5) + self._shared_save_optimizer_flag[0] = -1 + + if self._shm_model_weight is not None: + self._shm_model_weight.close() + self._shm_model_weight.unlink() + self._shm_model_weight = None + if self._shm_master_weight is not None: + self._shm_master_weight.close() + self._shm_master_weight.unlink() + self._shm_master_weight = None + if self._shm_optimizer_weight is not None: + self._shm_optimizer_weight.close() + self._shm_optimizer_weight.unlink() + self._shm_optimizer_weight = None + + if paddle.distributed.get_world_size() > 1: + dist.barrier() diff --git a/paddlenlp/trainer/unified_checkpoint/check_completion.py b/paddlenlp/trainer/unified_checkpoint/check_completion.py new file mode 100644 index 000000000000..cf337c468463 --- /dev/null +++ b/paddlenlp/trainer/unified_checkpoint/check_completion.py @@ -0,0 +1,249 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Unfied checkpoint check functions.""" + +import json +import os + +import paddle +import paddle.distributed as dist +from paddle.distributed import fleet + +from paddlenlp.trainer.utils.helper import distributed_file, distributed_isfile +from paddlenlp.utils.env import ( + PADDLE_MASTER_WEIGHTS_INDEX_NAME, + PADDLE_OPTIMIZER_INDEX_NAME, + SAFE_MASTER_WEIGHTS_INDEX_NAME, + SAFE_OPTIMIZER_INDEX_NAME, +) +from paddlenlp.utils.log import logger +from paddlenlp.utils.nested import flatten_list + +try: + from paddle.base import core +except: + core = None + +from .utils import ( + get_expected_state_dict, + is_sharding_split_param_mode, + select_model_weight_index, + update_master_weight_status, +) + +__all__ = ["check_unified_checkpoint", "check_unified_optimizer"] + + +def check_unified_checkpoint(args, model, resume_from_checkpoint, safe_serialization=False): + index_filename = select_model_weight_index(model, resume_from_checkpoint, safe_serialization, local=False) + index_filename = os.path.join(resume_from_checkpoint, index_filename) + # Find index json file and distribute this file in global group. + if distributed_isfile(index_filename): + distributed_file(index_filename) + else: + raise Exception( + f"Sorry, we can not find {index_filename}. This file should be appear at least on one machine." + ) + + with open(index_filename, "r") as f: + index = json.loads(f.read()) + all_weight_filenames = sorted(set(index["weight_map"].values())) + + # Get existed weight file list on current machine. + existed_filelist = [] + existed_files = [] + for filename in os.listdir(resume_from_checkpoint): + if filename in all_weight_filenames: + existed_files.append(filename) + + # Gather all the existed files in global group. + dist.all_gather_object(existed_filelist, existed_files) + flatten_existed_filelist = flatten_list(existed_filelist) + diff_filelist = list(set(all_weight_filenames).difference(set(flatten_existed_filelist))) + if len(diff_filelist) != 0: + raise Exception(f"Sorry, the weight file list on the machines is not complete!, missing {diff_filelist}") + + # To decide whether to load the checkpoint locally, or need to dynamically send tensors across machines. + local_resume = True + if args.dataset_rank == 0 or args.use_expert_parallel: + hcg = fleet.get_hybrid_communicate_group() + tp_group = hcg.get_model_parallel_group() + pp_group = hcg.get_pipe_parallel_group() + dp_group = hcg.get_data_parallel_group() + dp_rank = dp_group.rank if dp_group.nranks > 1 else 0 + + need_files = set() + state_dict = get_expected_state_dict(model) + for key in state_dict.keys(): + filename = index["weight_map"][key] + # When using expert parallel, there's no need to check tensors with `no_sync=False` when dp_rank > 0. 
+ if args.use_expert_parallel and dp_rank > 0 and not getattr(state_dict[key], "no_sync", False): + continue + need_files.add(filename) + diff_filelist = list(need_files.difference(set(existed_files))) + num_diff = paddle.to_tensor([len(diff_filelist)]) + if tp_group.nranks > 1: + dist.all_reduce(num_diff, op=dist.ReduceOp.MAX, group=tp_group) + if pp_group.nranks > 1: + dist.all_reduce(num_diff, op=dist.ReduceOp.MAX, group=pp_group) + if args.use_expert_parallel and dp_group.nranks > 1: + dist.all_reduce(num_diff, op=dist.ReduceOp.MAX, group=dp_group) + if num_diff.item() == 0: + local_resume = True + else: + local_resume = False + local_resume = paddle.to_tensor([local_resume]) + dist.all_reduce(local_resume, op=dist.ReduceOp.MIN) + local_resume = local_resume.item() + return local_resume + + +def check_unified_optimizer(args, model, optimizer, resume_from_checkpoint, safe_serialization=False): + if not safe_serialization: + index_filename, index_filename_master_weights = PADDLE_OPTIMIZER_INDEX_NAME, PADDLE_MASTER_WEIGHTS_INDEX_NAME + else: + index_filename, index_filename_master_weights = SAFE_OPTIMIZER_INDEX_NAME, SAFE_MASTER_WEIGHTS_INDEX_NAME + index_filename = os.path.join(resume_from_checkpoint, index_filename) + index_filename_master_weights = os.path.join(resume_from_checkpoint, index_filename_master_weights) + + # Find index json file and distribute the file in global group. + if distributed_isfile(index_filename): + distributed_file(index_filename) + else: + raise Exception( + f"Sorry, we can not find {index_filename}. This file should be appear at least on one machine." + ) + + with open(index_filename, "r") as f: + index = json.loads(f.read()) + all_optimizer_filenames = sorted(set(index["weight_map"].values())) + + has_master_weights = index["master_weights"] + # update has_master_weights and index_filename_master_weights + # 1. if the master weight exists, only has_master_weights is set True and loaded when needed + # 2. if master weight does not exist, convert model weight to master weight when needed + has_master_weights, index_filename_master_weights = update_master_weight_status( + args, optimizer, has_master_weights, safe_serialization + ) + if has_master_weights: + index_filename_master_weights = os.path.join(resume_from_checkpoint, index_filename_master_weights) + if distributed_isfile(index_filename_master_weights): + distributed_file(index_filename_master_weights) + else: + raise Exception( + f"Sorry, we can not find {index_filename_master_weights}. This file should be appear at least on one machine." + ) + with open(index_filename_master_weights, "r") as f: + index_mw = json.loads(f.read()) + all_mw_filenames = sorted(set(index_mw["weight_map"].values())) + + hcg = fleet.get_hybrid_communicate_group() + tp_group = hcg.get_model_parallel_group() + pp_group = hcg.get_pipe_parallel_group() + dp_group = hcg.get_data_parallel_group() + sharding_group = hcg.get_sharding_parallel_group() + sharding_rank = sharding_group.rank + dp_rank = dp_group.rank if dp_group.nranks > 1 else 0 + struct2static_name_mappings = {k: v.name for k, v in model.state_dict().items()} + + if is_sharding_split_param_mode(args): + # We do not check optimizer files completion for split_param, since it is very complicated. Directly support local resume. 
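# (split_param is only accepted together with sharding stage1 and amp_master_grad, see the
#  training_args.py hunk earlier in this diff; in that mode every rank is assumed to find its
#  own shard of the optimizer state on local disk, hence the unconditional True below.)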
+ logger.warning("We only support local resume for split_param mode, do not support dynamically loading.") + return True + + if sharding_group.nranks > 1: + param2rank = optimizer._param2rank + + def check_complete(all_filenames): + # Check whether the checkpoint files on machines are complete. If not complete, raise Exception. + existed_filelist = [] + existed_files = [] + for filename in os.listdir(resume_from_checkpoint): + if filename in all_filenames: + existed_files.append(filename) + + dist.all_gather_object(existed_filelist, existed_files) + flatten_existed_filelist = flatten_list(existed_filelist) + diff_filelist = list(set(all_filenames).difference(set(flatten_existed_filelist))) + if len(diff_filelist) != 0: + raise Exception( + f"Sorry, the optimizer file list on `data_parallel_rank==0` machines is not complete!, missing {diff_filelist}" + ) + return existed_files + + def check_dynamic_load(args, weight_map, existed_files, is_master_weights=False, typename_set=None): + # To decide whether to load the checkpoint locally, or need to dynamically distribute the checkpoint. + local_resume = True + if args.data_parallel_rank == 0 or args.use_expert_parallel: + need_files = set() + state_dict = get_expected_state_dict(model) + + for key in state_dict.keys(): + if sharding_group.nranks > 1: + static_name = struct2static_name_mappings.get(key, None) + param_rank = param2rank.get(static_name, None) + if param_rank != sharding_rank: + continue + + # When using expert parallel, there's no need to check tensors with `no_sync=False` when dp_rank > 0. + if args.use_expert_parallel and dp_rank > 0 and not getattr(state_dict[key], "no_sync", False): + continue + + if is_master_weights and state_dict[key].dtype == core.VarDesc.VarType.FP32: + continue + + if not is_master_weights: + for type_name in typename_set: + type_key = key + "/" + type_name + filename = weight_map[type_key] + need_files.add(filename) + else: + filename = weight_map[key] + need_files.add(filename) + + diff_filelist = list(need_files.difference(set(existed_files))) + num_diff = paddle.to_tensor([len(diff_filelist)]) + if tp_group.nranks > 1: + dist.all_reduce(num_diff, op=dist.ReduceOp.MAX, group=tp_group) + if pp_group.nranks > 1: + dist.all_reduce(num_diff, op=dist.ReduceOp.MAX, group=pp_group) + if sharding_group.nranks > 1: + dist.all_reduce(num_diff, op=dist.ReduceOp.MAX, group=sharding_group) + if args.use_expert_parallel and dp_group.nranks > 1: + dist.all_reduce(num_diff, op=dist.ReduceOp.MAX, group=dp_group) + + if num_diff.item() == 0: + local_resume = True + else: + local_resume = False + local_resume = paddle.to_tensor([local_resume]) + dist.all_reduce(local_resume, op=dist.ReduceOp.MIN) + return local_resume.item() + + # check whether the optimizer checkpoint files are complete. + existed_files = check_complete(all_optimizer_filenames) + if has_master_weights: + existed_files_mw = check_complete(all_mw_filenames) + # get optimizer's param type name, like moment1_0. 
+ typename_set = set() + for key in index["weight_map"].keys(): + _, typename = key.split("/") + typename_set.add(typename) + local_resume = check_dynamic_load( + args, index["weight_map"], existed_files, is_master_weights=False, typename_set=typename_set + ) + local_resume_rw = True + if has_master_weights: + local_resume_rw = check_dynamic_load(args, index_mw["weight_map"], existed_files_mw, is_master_weights=True) + return local_resume & local_resume_rw diff --git a/paddlenlp/trainer/unified_checkpoint/load_dynamic.py b/paddlenlp/trainer/unified_checkpoint/load_dynamic.py new file mode 100644 index 000000000000..064ecacc7c3c --- /dev/null +++ b/paddlenlp/trainer/unified_checkpoint/load_dynamic.py @@ -0,0 +1,495 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unified Checkpoint Dynamic Loading Functions""" + +import copy +import json +import os +import sys + +import paddle +import paddle.distributed as dist +from paddle.distributed import fleet + +try: + from paddle.base import core +except: + core = None + +from paddlenlp.peft import LoRAModel, PrefixModelForCausalLM +from paddlenlp.transformers.model_utils import _load_state_dict_into_model +from paddlenlp.transformers.utils import device_guard, is_safetensors_available +from paddlenlp.utils.env import ( + PADDLE_MASTER_WEIGHTS_INDEX_NAME, + PADDLE_OPTIMIZER_INDEX_NAME, + SAFE_MASTER_WEIGHTS_INDEX_NAME, + SAFE_OPTIMIZER_INDEX_NAME, +) +from paddlenlp.utils.log import logger +from paddlenlp.utils.nested import nested_copy + +if is_safetensors_available(): + if sys.platform.startswith("win"): + from safetensors import safe_open + else: + from paddlenlp.utils.safetensors import fast_safe_open as safe_open + +from .utils import ( + FP32_MASTER, + get_expected_state_dict, + mapping_optimizer_tp_actions, + optimizer_non_scaler_name, + optimizer_scalar_name, + select_model_weight_index, + update_master_weight_status, +) + +__all__ = ["load_unified_checkpoint_dynamically", "load_unified_optimizer_dynamically"] + + +def create_send_table(file_keyname_mappings, file_machine_mappings): + send_table = {} + global_rank = dist.get_rank() + local_rank = int(os.getenv("PADDLE_RANK_IN_NODE", 0)) + local_device_count = int(os.getenv("PADDLE_LOCAL_SIZE")) + for filename, keys in file_keyname_mappings.items(): + machine = file_machine_mappings[filename][0] + is_src = (global_rank // local_device_count) == machine + for i, key in enumerate(keys): + if is_src and local_rank == i % local_device_count: + send_table[key] = global_rank + dispatch_list = [] + dist.all_gather_object(dispatch_list, send_table) + send_table = {} + for dl in dispatch_list: + send_table.update(dl) + return send_table + + +def create_dispatch_table(args, model, file_keyname_mappings, file_machine_mappings): + """Create dispatch table for dynamically loading state dict. 
+ + Args: + args + """ + + hcg = fleet.get_hybrid_communicate_group() + tp_group = hcg.get_model_parallel_group() + tp_rank = tp_group.rank + + # Create tensor receive table, contains {"key0": [global_rank, tp_rank], "key1": [global_rank, tp_rank]} + dispatch_list = [] + recv_table = {} + if args.dataset_rank == 0: + state_dict = get_expected_state_dict(model) + for (k, v) in state_dict.items(): + if hasattr(v, "is_distributed") and v.is_distributed: + recv_table[k] = [(dist.get_rank(), tp_rank)] + else: + recv_table[k] = [(dist.get_rank(), -1)] + + # Gather receive table in global group. + dist.all_gather_object(dispatch_list, recv_table) + recv_table = {} + for dl in dispatch_list: + for key, value in dl.items(): + if key not in recv_table: + recv_table[key] = value + else: + recv_table[key] += value + + # Create send table, to decide which worker to send the key. Contains {"key0:" global_rank, "key1": global_rank, ...} + send_table = create_send_table(file_keyname_mappings, file_machine_mappings) + + return send_table, recv_table + + +def create_optimizer_dispatch_table( + args, + model, + optimizer, + file_keyname_mappings, + file_machine_mappings, + struct2static_name_mappings, + is_master_weights=False, + typename_set=None, +): + hcg = fleet.get_hybrid_communicate_group() + tp_group = hcg.get_model_parallel_group() + sharding_group = hcg.get_sharding_parallel_group() + sharding_rank = sharding_group.rank + if sharding_group.nranks > 1: + param2rank = optimizer._param2rank + tp_rank = tp_group.rank + + # Create receive table, contains {"param_key0": [global_rank, tp_rank], "param_key1": [global_rank, tp_rank]} + dispatch_list = [] + recv_table = {} + if args.data_parallel_rank == 0: + state_dict = get_expected_state_dict(model) + for (k, v) in state_dict.items(): + if sharding_group.nranks > 1: + static_name = struct2static_name_mappings[k] + param_rank = param2rank.get(static_name, None) + if param_rank != sharding_rank: + continue + if is_master_weights: + if hasattr(v, "is_distributed") and v.is_distributed: + recv_table[k] = [(dist.get_rank(), tp_rank)] + else: + recv_table[k] = [(dist.get_rank(), -1)] + else: + for typename in typename_set: + type_key = k + "/" + typename + if typename in optimizer_non_scaler_name: + if hasattr(v, "is_distributed") and v.is_distributed: + recv_table[type_key] = [(dist.get_rank(), tp_rank)] + else: + recv_table[type_key] = [(dist.get_rank(), -1)] + else: + recv_table[type_key] = [(dist.get_rank(), -1)] + + dist.all_gather_object(dispatch_list, recv_table) + recv_table = {} + for dl in dispatch_list: + for k, v in dl.items(): + if k not in recv_table: + recv_table[k] = v + else: + recv_table[k] += v + + # Create send table, to decide which worker to send the key. 
Contains {"param_key0:" 0, "param_key1": 1, ...} + send_table = create_send_table(file_keyname_mappings, file_machine_mappings) + return send_table, recv_table + + +def get_file_mappings(index, resume_from_checkpoint): + file_keyname_mappings = {} + for k, v in index["weight_map"].items(): + if v not in file_keyname_mappings: + file_keyname_mappings[v] = [] + file_keyname_mappings[v].append(k) + for k in file_keyname_mappings.keys(): + file_keyname_mappings[k] = sorted(file_keyname_mappings[k]) + + local_device_count = int(os.getenv("PADDLE_LOCAL_SIZE")) + local_rank = int(os.getenv("PADDLE_RANK_IN_NODE", 0)) + global_rank = dist.get_rank() + file_machine_mappings = {} + for filename in file_keyname_mappings.keys(): + if local_rank == 0 and os.path.exists(os.path.join(resume_from_checkpoint, filename)): + file_machine_mappings[filename] = [global_rank // local_device_count] + file_machine_list = [] + dist.all_gather_object(file_machine_list, file_machine_mappings) + file_machine_mappings = {} + for mappings in file_machine_list: + for k, v in mappings.items(): + if k not in file_machine_mappings: + file_machine_mappings[k] = v + else: + file_machine_mappings[k] += v + return file_keyname_mappings, file_machine_mappings + + +def distributed_send_recv( + state_dict, + tp_actions, + send_table, + recv_table, + resume_from_checkpoint, + file_keyname_mappings, + file_machine_mappings, +): + + local_device_count = int(os.getenv("PADDLE_LOCAL_SIZE")) + global_rank = dist.get_rank() + for filename in file_keyname_mappings.keys(): + machine = file_machine_mappings[filename][0] + is_src = global_rank // local_device_count == machine + if is_src: + f = safe_open(os.path.join(resume_from_checkpoint, filename), framework="np") + + for key in file_keyname_mappings[filename]: + recv_info = recv_table[key] + recv_ranklist = [a for (a, _) in recv_info] + if is_src and global_rank == send_table[key]: + py_safe_slice_ = f.get_slice(key) + # send + if key in tp_actions: + weight = tp_actions[key](py_safe_slice_) + # copy weight to GPU + for j in range(len(weight)): + with device_guard(): + weight[j] = paddle.Tensor(weight[j], zero_copy=True) + weight[j] = weight[j]._copy_to(paddle.framework._current_expected_place(), False) + + for recv_rank, split_index in recv_info: + if recv_rank == global_rank: + state_dict[key] = weight[split_index] + else: + dist.stream.send(weight[split_index], dst=recv_rank) + else: + # no need to tp split + weight = py_safe_slice_[:] + with device_guard(): + weight = paddle.Tensor(weight, zero_copy=True) + weight = weight._copy_to(paddle.framework._current_expected_place(), False) + for recv_rank, _ in recv_info: + if recv_rank == global_rank: + state_dict[key] = weight + else: + dist.stream.send(weight, dst=recv_rank) + + if global_rank != send_table[key] and global_rank in recv_ranklist: + dist.stream.recv(state_dict[key], src=send_table[key]) + + if is_src: + f.__exit__(None, None, None) + + return state_dict + + +def load_unified_checkpoint_dynamically(args, model, resume_from_checkpoint, safe_serialization=False): + index_filename = select_model_weight_index(model, resume_from_checkpoint, safe_serialization, local=False) + index_filename = os.path.join(resume_from_checkpoint, index_filename) + + with open(index_filename, "r") as f: + index = json.loads(f.read()) + + # `file_keyname_mappings` indicates which keys each file contains. 
For example, {"model-00001-of-00002.safetensors": ["llama.embed_tokens.weight", "llama.layers.0.self_attn.q_proj.weight", ...]} + # `file_machine_mappings` indicates the machine where the files appear. For example, {"model-00001-of-00002.safetensors": [machine_0, machine_1], "model-00002-of-00002.safetensors": [machine_0]} + file_keyname_mappings, file_machine_mappings = get_file_mappings(index, resume_from_checkpoint) + + logger.debug("Creating dispatch table for unified checkpoint load ...") + # Get send_table and recv_table. The send table indicates which workers are responsible for sending tensors, and the recv table indicates which workers should receive the tensors. + send_table, recv_table = create_dispatch_table( + args, + model, + file_keyname_mappings, + file_machine_mappings, + ) + + # Get all the keys that are splited by tensor parallelism. + all_tp_keys = set() + for k, v in recv_table.items(): + if v[0][1] != -1: + all_tp_keys.add(k) + + config_revise = copy.deepcopy(model.config) + config_revise.tensor_parallel_rank = None + if len(all_tp_keys) == 0: + tp_actions = {} + else: + # Get corresponding tensor parallel actions. + if isinstance(model, LoRAModel) or isinstance(model, PrefixModelForCausalLM): + tp_actions = model._get_tensor_parallel_convert_actions( + set(all_tp_keys), is_split=True, ignore_error=True, config=config_revise + ) + else: + tp_actions = model.get_tensor_parallel_convert_actions(config_revise, all_tp_keys, ignore_error=True) + + logger.debug("Distributed send recv for state dict load ...") + # Distribute the checkpoint tensor dynamically, using the `send_table` and `recv_table` we create before. + state_dict = distributed_send_recv( + get_expected_state_dict(model), + tp_actions, + send_table, + recv_table, + resume_from_checkpoint, + file_keyname_mappings, + file_machine_mappings, + ) + dist.barrier() + logger.debug("Setting state dict into model ...") + error_msgs = _load_state_dict_into_model(model, state_dict, "") + if len(error_msgs) > 0: + error_msg = "\n\t".join(error_msgs) + raise RuntimeError(f"Error(s) in loading dynamic state_dict for {model.__class__.__name__}:\n\t{error_msg}") + + +def load_unified_optimizer_dynamically(args, model, optimizer, resume_from_checkpoint, safe_serialization=False): + optim_state_dict = nested_copy(optimizer.state_dict()) + if "master_weights" in optim_state_dict.keys(): + optim_state_dict.pop("master_weights") + + if safe_serialization: + index_filename, index_filename_mw = SAFE_OPTIMIZER_INDEX_NAME, SAFE_MASTER_WEIGHTS_INDEX_NAME + else: + index_filename, index_filename_mw = PADDLE_OPTIMIZER_INDEX_NAME, PADDLE_MASTER_WEIGHTS_INDEX_NAME + + with open(os.path.join(resume_from_checkpoint, index_filename), "r") as f: + index = json.loads(f.read()) + + # `file_keyname_mappings` indicates which keys each file contains. For example, {"optimizer-00001-of-00002.safetensors": ["llama.embed_tokens.weight/moment1_0", "llama.layers.1.mlp.gate_proj.weight/moment1_0", ...]} + # `file_machine_mappings` indicates the machine where the files appear. For example, {"optimizer-00001-of-00002.safetensors": [machine_0, machine_1], "optimizer-00002-of-00002.safetensors": [machine_0]} + file_keyname_mappings, file_machine_mappings = get_file_mappings(index, resume_from_checkpoint) + + has_master_weights = index["master_weights"] + # update has_master_weights and index_filename_master_weights + # 1. if the master weights exists, only has_master_weights is set True and load master weights when needed + # 2. 
if master weights does not exist, convert model weights to master weights when needed + has_master_weights, index_filename_mw = update_master_weight_status( + args, optimizer, has_master_weights, safe_serialization + ) + + if has_master_weights: + with open(os.path.join(resume_from_checkpoint, index_filename_mw), "r") as f: + index_mw = json.loads(f.read()) + file_keyname_mappings_mw, file_machine_mappings_mw = get_file_mappings(index_mw, resume_from_checkpoint) + + # Get optimizer param type name, like moment1_0, moment2_0, beta1_pow_acc_0. + typename_set = set() + for key in index["weight_map"].keys(): + _, typename = key.split("/") + typename_set.add(typename) + + model_state_dict = get_expected_state_dict(model) + struct2static_name_mappings = {k: v.name for k, v in model_state_dict.items()} + static2struct_name_mappings = {v.name: k for k, v in model_state_dict.items()} + # Get send_table and recv_table. The send table indicates which workers are responsible for sending tensors, and the recv table indicates which workers should receive the tensors. + send_table, recv_table = create_optimizer_dispatch_table( + args, + model, + optimizer, + file_keyname_mappings, + file_machine_mappings, + struct2static_name_mappings, + is_master_weights=False, + typename_set=typename_set, + ) + if has_master_weights: + send_table_mw, recv_table_mw = create_optimizer_dispatch_table( + args, + model, + optimizer, + file_keyname_mappings_mw, + file_machine_mappings_mw, + struct2static_name_mappings, + is_master_weights=True, + ) + + # Initialize optimizer state dict. + hcg = fleet.get_hybrid_communicate_group() + sharding_group = hcg.get_sharding_parallel_group() + if sharding_group.nranks > 1: + param2rank = optimizer._param2rank + optim_state_dict_mw = {} + + def check_optimizer_param(parameter): + if sharding_group.nranks > 1: + param_rank = param2rank.get(parameter.name, None) + if param_rank != sharding_group.rank: + return False + if parameter.stop_gradient: + return False + return True + + optimizer_keys_with_shape = [] + if isinstance(optimizer._parameter_list[0], dict): + for param_group in optimizer._parameter_list: + # If parameter groups are set, there must be `params` key. This is guaranteed by the optimizer's initialization code. + for parameter in param_group["params"]: + if check_optimizer_param(parameter): + optimizer_keys_with_shape.append((parameter.name, parameter.shape)) + else: + for parameter in optimizer._parameter_list: + if check_optimizer_param(parameter): + optimizer_keys_with_shape.append((parameter.name, parameter.shape)) + + # see how to change + for static_name, shape in optimizer_keys_with_shape: + k = static2struct_name_mappings[static_name] + for typename in typename_set: + new_k = k + "/" + typename + if typename in optimizer_scalar_name: + optim_state_dict[new_k] = paddle.empty([1], dtype="float32") + else: + optim_state_dict[new_k] = paddle.empty(shape, dtype="float32") + if has_master_weights: + optim_state_dict_mw[k] = paddle.empty(shape, dtype="float32") + + # Get all the keys that are splited by tensor parallelism. + all_tp_keys = set() + for k, v in recv_table.items(): + structure_name, typename = k.split("/") + if typename in optimizer_non_scaler_name: + if v[0][1] != -1: + all_tp_keys.add(structure_name) + + # Get corresponding tensor parallel actions. 
+ config_revise = copy.deepcopy(model.config) + config_revise.tensor_parallel_rank = None + if len(all_tp_keys) == 0: + tp_actions = {} + else: + if isinstance(model, LoRAModel) or isinstance(model, PrefixModelForCausalLM): + tp_actions = model._get_tensor_parallel_convert_actions( + set(all_tp_keys), is_split=True, ignore_error=True, config=config_revise + ) + else: + tp_actions = model.get_tensor_parallel_convert_actions(config_revise, all_tp_keys, ignore_error=True) + optimizer_keys = list(index["weight_map"].keys()) + optimizer_tp_actions = mapping_optimizer_tp_actions(tp_actions, optimizer_keys) + if has_master_weights: + optimizer_tp_actions.update(tp_actions) + + # Distribute the optimizer checkpoint dynamically, using the `send_table` and `recv_table` we create before. + optim_state_dict = distributed_send_recv( + optim_state_dict, + optimizer_tp_actions, + send_table, + recv_table, + resume_from_checkpoint, + file_keyname_mappings, + file_machine_mappings, + ) + dist.barrier() + if has_master_weights: + optim_state_dict_mw = distributed_send_recv( + optim_state_dict_mw, + optimizer_tp_actions, + send_table_mw, + recv_table_mw, + resume_from_checkpoint, + file_keyname_mappings_mw, + file_machine_mappings_mw, + ) + dist.barrier() + + # Rename optimizer state dict. + for key in list(optim_state_dict.keys()): + if key == "LR_Scheduler": + continue + key_name = key.split("/") + static_name = struct2static_name_mappings[key_name[0]] + if has_master_weights: + if model_state_dict[key_name[0]].dtype != core.VarDesc.VarType.FP32: + key_name = "_".join([static_name, FP32_MASTER, key_name[1]]) + else: + key_name = "_".join([static_name, key_name[1]]) + else: + key_name = "_".join([static_name, key_name[1]]) + optim_state_dict[key_name] = optim_state_dict.pop(key) + optim_state_dict[key_name].name = key_name + + if has_master_weights: + optim_state_dict["master_weights"] = {} + for key in list(optim_state_dict_mw.keys()): + static_name = struct2static_name_mappings[key] + optim_state_dict["master_weights"][static_name] = optim_state_dict_mw.pop(key) + optim_state_dict["master_weights"][static_name].name = "_".join([static_name, FP32_MASTER]) + + if args.data_parallel_rank == 0: + return optim_state_dict + return None diff --git a/paddlenlp/trainer/unified_checkpoint/load_local.py b/paddlenlp/trainer/unified_checkpoint/load_local.py new file mode 100644 index 000000000000..552289d8f383 --- /dev/null +++ b/paddlenlp/trainer/unified_checkpoint/load_local.py @@ -0,0 +1,270 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
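The load_local.py module that follows resolves the sharded index, derives the set of keys the current worker expects from the model state dict, and skips any shard file whose keys are disjoint from that set, so each pipeline stage only reads the shards it needs. A minimal sketch of that shard-selection step, assuming a standard safetensors-style index JSON and hypothetical file and key names (this is an illustration, not the PaddleNLP loader itself):

# Illustrative sketch (hypothetical names): pick only the shard files a worker needs,
# mirroring the `expected_keys.isdisjoint(file_map[...])` check used by the loader.
import json
from typing import Dict, List, Set


def shards_to_read(index_path: str, expected_keys: Set[str]) -> List[str]:
    """Return the shard files that contain at least one expected key."""
    with open(index_path, "r") as f:
        weight_map: Dict[str, str] = json.load(f)["weight_map"]  # key -> shard filename

    file_map: Dict[str, Set[str]] = {}
    for key, filename in weight_map.items():
        file_map.setdefault(filename, set()).add(key)

    # A shard is skipped when none of its keys are expected on this worker.
    return sorted(name for name, keys in file_map.items() if not expected_keys.isdisjoint(keys))


# Example: a pipeline stage that only owns the embedding weight would read a single shard.
# shards_to_read("model.safetensors.index.json", {"llama.embed_tokens.weight"})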
+"""Unfied checkpoint locally loading functions.""" + +import gc +import os + +from tqdm.auto import tqdm + +try: + from paddle.base import core +except: + core = None + +from paddlenlp.peft import LoRAModel, PrefixModelForCausalLM +from paddlenlp.transformers.model_utils import ( + _load_state_dict_into_model, + faster_set_state_dict, + load_state_dict, +) +from paddlenlp.transformers.utils import get_checkpoint_shard_files +from paddlenlp.utils.env import ( + PADDLE_MASTER_WEIGHTS_INDEX_NAME, + PADDLE_OPTIMIZER_INDEX_NAME, + SAFE_MASTER_WEIGHTS_INDEX_NAME, + SAFE_OPTIMIZER_INDEX_NAME, +) +from paddlenlp.utils.log import logger +from paddlenlp.utils.nested import nested_copy + +from .sharding_split_param_utils import load_unified_optimizer_split_param +from .utils import ( + FP32_MASTER, + get_expected_keys, + get_expected_state_dict, + get_optimizer_shard_files, + is_sharding_split_param_mode, + mapping_optimizer_tp_actions, + select_model_weight_index, + update_master_weight_status, +) + +__all__ = ["load_unified_checkpoint_locally", "load_unified_optimizer_locally"] + + +def load_unified_checkpoint_locally(args, model, resume_from_checkpoint: str, safe_serialization=False): + """ + Only dataset_rank == 0 or using expert parallel can enter this function. + """ + index_filename = select_model_weight_index(model, resume_from_checkpoint, safe_serialization, local=True) + + resolved_archive_file, sharded_metadata = get_checkpoint_shard_files( + pretrained_model_name_or_path=resume_from_checkpoint, + index_filename=os.path.join(resume_from_checkpoint, index_filename), + ) + loaded_keys = sharded_metadata["all_checkpoint_keys"] + + model_state_dict = get_expected_state_dict(model) + # If using expert parallel, when dp_rank > 0, need to modify the expected_keys here. + if not args.use_expert_parallel or (args.use_expert_parallel and args.data_parallel_rank == 0): + expected_keys = set(list(model_state_dict.keys())) + else: + expected_keys = set() + for key in model_state_dict.keys(): + if getattr(model_state_dict[key], "no_sync", False): + expected_keys.add(key) + missing_keys = expected_keys - set(loaded_keys) + + use_fast_set = True + if isinstance(model, LoRAModel) or isinstance(model, PrefixModelForCausalLM): + use_fast_set = False + + if len(missing_keys) > 0: + raise ValueError(f"missing_keys: {missing_keys}") + + def _remove_unused_keys( + state_dict, + model_state_dict, + ): + unused_keys = set(state_dict.keys()) - set(model_state_dict.keys()) + for unused_key in unused_keys: + del state_dict[unused_key] + return unused_keys + + # This should always be a list but, just to be sure. + if not isinstance(resolved_archive_file, list): + resolved_archive_file = [resolved_archive_file] + + error_msgs = [] + + if len(resolved_archive_file) > 1: + resolved_archive_file = tqdm(resolved_archive_file, desc="Loading checkpoint shards") + + for shard_file in resolved_archive_file: + # TODO: check if no expected_keys in shard_file, then don't load it + if expected_keys.isdisjoint(sharded_metadata["file_map"][os.path.split(shard_file)[-1]]): + continue + + pre_tensor_parallel_split = False + if shard_file.endswith(".safetensors") and model.config.tensor_parallel_degree > 1: + pre_tensor_parallel_split = True + assert loaded_keys is not None, "loaded_keys is not None." 
+ if isinstance(model, LoRAModel) or isinstance(model, PrefixModelForCausalLM): + tp_actions = model._get_tensor_parallel_convert_actions( + set(loaded_keys), is_split=True, ignore_error=True + ) + else: + tp_actions = model.get_tensor_parallel_convert_actions(model.config, loaded_keys, ignore_error=True) + # Here we use expected_keys to optimize weights loading for pipeline model. Only works for safetensors + state_dict = load_state_dict( + shard_file, tp_actions if pre_tensor_parallel_split else None, expected_keys, device="expected" + ) + + if not pre_tensor_parallel_split: + # Since we load all keys but we only need one of pipeline stages + _ = _remove_unused_keys(state_dict, model_state_dict) + + if model.config.tensor_parallel_degree > 1 and not pre_tensor_parallel_split: + logger.info("Converting state_dict to Tensor Parallel Format") + # ignore error for multi shard, since only parts of data + state_dict = model.convert_tensor_parallel( + None, model.config, state_dict=state_dict, ignore_error=len(resolved_archive_file) > 1 + ) + + if use_fast_set: + error_msgs += faster_set_state_dict(model, state_dict, strict_dtype=False) + else: + error_msgs += _load_state_dict_into_model(model, state_dict, "") + + # force memory release + del state_dict + # gc.collect() + + if len(error_msgs) > 0: + error_msg = "\n\t".join(error_msgs) + if " but the expected shape is" in error_msg: + error_msg += ( + "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method." + ) + raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}") + + +def load_unified_optimizer_locally(args, model, optimizer, resume_from_checkpoint, safe_serialization=False): + # Special process with split param. + if is_sharding_split_param_mode(args): + returned_optim_state_dict = load_unified_optimizer_split_param(model, optimizer, resume_from_checkpoint) + return returned_optim_state_dict + + # init and get optimizer LR_Scheduler + returned_optim_state_dict = nested_copy(optimizer.state_dict()) + + if not safe_serialization: + index_filename, index_filename_master_weights = ( + PADDLE_OPTIMIZER_INDEX_NAME, + PADDLE_MASTER_WEIGHTS_INDEX_NAME, + ) + else: + index_filename, index_filename_master_weights = SAFE_OPTIMIZER_INDEX_NAME, SAFE_MASTER_WEIGHTS_INDEX_NAME + + resolved_archive_file, sharded_metadata = get_optimizer_shard_files( + optimizer_path=resume_from_checkpoint, + index_filename=os.path.join(resume_from_checkpoint, index_filename), + ) + has_master_weights = True if sharded_metadata["master_weights"] else False + + model_state_dict = get_expected_state_dict(model) + model_keys = list(model_state_dict.keys()) + struct2static_name_mappings = {k: v.name for k, v in model_state_dict.items()} # get optimizer param mappings + + expected_keys = get_expected_keys(args, sharded_metadata, model, optimizer) + + # This should always be a list but, just to be sure. + if not isinstance(resolved_archive_file, list): + resolved_archive_file = [resolved_archive_file] + + if len(resolved_archive_file) > 1: + resolved_archive_file = tqdm(resolved_archive_file, desc="Loading optimizer shards") + + # update has_master_weights and index_filename_master_weights + # 1. if the master weight exists, only has_master_weights is set True and loaded when needed + # 2. 
if master weight does not exist, convert model weight to master weight when needed + has_master_weights, index_filename_master_weights = update_master_weight_status( + args, optimizer, has_master_weights, safe_serialization + ) + + if has_master_weights: + returned_optim_state_dict["master_weights"] = {} + + resolved_archive_file_mw, sharded_metadata_mw = get_optimizer_shard_files( + optimizer_path=resume_from_checkpoint, + index_filename=os.path.join(resume_from_checkpoint, index_filename_master_weights), + ) + + expected_keys_mw = get_expected_keys(args, sharded_metadata_mw, model, optimizer, is_master_weights=True) + if not isinstance(resolved_archive_file_mw, list): + resolved_archive_file_mw = [resolved_archive_file_mw] + if len(resolved_archive_file_mw) > 1: + resolved_archive_file_mw = tqdm(resolved_archive_file_mw, desc="Loading master weights shards") + + def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected_keys, is_master_weights=False): + returned_state_dict = {} + # load optimizer + for shard_file in resolved_archive_file: + # TODO: check if no expected_keys in shard_file, then don't load it + if expected_keys.isdisjoint(sharded_metadata["file_map"][os.path.split(shard_file)[-1]]): + continue + + if shard_file.endswith(".safetensors"): + # assert model_keys is not None, "model_keys is None." TODO: correct the assert + if model.config.tensor_parallel_degree > 1: + if isinstance(model, LoRAModel) or isinstance(model, PrefixModelForCausalLM): + tp_actions = model._get_tensor_parallel_convert_actions( + model_keys, is_split=True, ignore_error=True + ) + else: + tp_actions = model.get_tensor_parallel_convert_actions( + model.config, model_keys, ignore_error=True + ) + if not is_master_weights: + tp_actions = mapping_optimizer_tp_actions(tp_actions, expected_keys) + + # Here we use expected_keys to optimize weights loading for pipeline model. 
Only works for safetensors + state_dict = load_state_dict(shard_file, tp_actions, expected_keys, device="expected") + else: + # for pipeline model, we don't need to use tp_actions + state_dict = load_state_dict(shard_file, None, expected_keys, device="expected") + + returned_state_dict.update(state_dict) + # force memory release + del state_dict + gc.collect() + return returned_state_dict + + state_dict_optim = load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected_keys) + if has_master_weights: + state_dict_master_weight = load_resolved_archive_file( + resolved_archive_file_mw, sharded_metadata_mw, expected_keys_mw, is_master_weights=True + ) + # rename optimizer param + for key in list(state_dict_optim.keys()): + key_name = key.split("/") + static_name = struct2static_name_mappings[key_name[0]] + if has_master_weights: + if model_state_dict[key_name[0]].dtype != core.VarDesc.VarType.FP32: + key_name = "_".join([static_name, FP32_MASTER, key_name[1]]) + else: + key_name = "_".join([static_name, key_name[1]]) + else: + key_name = "_".join([static_name, key_name[1]]) + returned_optim_state_dict[key_name] = state_dict_optim.pop(key) + returned_optim_state_dict[key_name].name = key_name + + if has_master_weights: + for key in list(state_dict_master_weight.keys()): + static_name = struct2static_name_mappings[key] + returned_optim_state_dict["master_weights"][static_name] = state_dict_master_weight.pop(key) + returned_optim_state_dict["master_weights"][static_name].name = "_".join([static_name, FP32_MASTER]) + + return returned_optim_state_dict diff --git a/paddlenlp/trainer/unified_checkpoint/load_save_single_card.py b/paddlenlp/trainer/unified_checkpoint/load_save_single_card.py new file mode 100644 index 000000000000..c8d514dda55f --- /dev/null +++ b/paddlenlp/trainer/unified_checkpoint/load_save_single_card.py @@ -0,0 +1,240 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
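The load_save_single_card.py module that follows handles the non-distributed case: every tensor goes into one safetensors file, and an index JSON records a `weight_map` pointing all keys at that single file plus a `metadata.total_size` summed from each tensor's element count times its dtype byte size. A minimal sketch of how such an index is shaped, assuming numpy tensors and a hand-rolled byte-size table in place of the library's `dtype_byte_size` helper (names here are illustrative, not the actual implementation):

# Illustrative sketch: the shape of the single-file index JSON written by the single-card saver.
# The byte-size table and helper name are assumptions for the example.
import json

import numpy as np

_BYTES_PER_ELEM = {"float32": 4, "float16": 2, "int8": 1}


def build_single_file_index(state_dict, weight_filename="model-00001-of-00001.safetensors"):
    weight_map = {key: weight_filename for key in state_dict}
    total_size = sum(
        int(np.prod(tensor.shape)) * _BYTES_PER_ELEM[str(tensor.dtype)] for tensor in state_dict.values()
    )
    return {"metadata": {"total_size": total_size}, "weight_map": weight_map}


# Example usage:
# index = build_single_file_index({"linear.weight": np.zeros((4, 4), dtype="float32")})
# print(json.dumps(index, indent=4))  # total_size == 64 bytes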
+"""Save and load single card checkpoint for Unified Checkpoint""" + +import gc +import json +import os + +import paddle + +try: + from paddle.base import core +except: + core = None + +from paddlenlp.peft import LoRAModel, PrefixModelForCausalLM +from paddlenlp.transformers.model_utils import ( + _load_state_dict_into_model, + load_state_dict, +) +from paddlenlp.transformers.utils import ( + dtype_byte_size, + get_checkpoint_shard_files, + is_safetensors_available, +) +from paddlenlp.utils.env import ( + SAFE_MASTER_WEIGHTS_INDEX_NAME, + SAFE_OPTIMIZER_INDEX_NAME, + SAFE_PEFT_WEIGHTS_INDEX_NAME, + SAFE_WEIGHTS_INDEX_NAME, +) +from paddlenlp.utils.log import logger +from paddlenlp.utils.nested import nested_copy + +if is_safetensors_available(): + from safetensors.numpy import save_file as safe_save_file + +from .utils import ( + FP32_MASTER, + generate_base_static_name, + get_expected_state_dict, + get_optimizer_shard_files, + save_model_config, +) + +__all__ = [ + "load_single_card_checkpoint", + "load_single_card_optimizer", + "save_single_card_checkpoint", + "save_single_card_optimizer", +] + + +def save_file_sync(state_dict, path): + for k in list(state_dict.keys()): + if isinstance(state_dict[k], paddle.Tensor): + state_dict[k] = state_dict.pop(k).cpu().numpy() + safe_save_file(state_dict, path, metadata={"format": "np"}) + + +def save_single_card_checkpoint(model_to_save, output_dir): + """Save checkpoint for non-distributed environment.""" + + state_dict = get_expected_state_dict(model_to_save) + if isinstance(model_to_save, LoRAModel) or isinstance(model_to_save, PrefixModelForCausalLM): + weight_filename = "peft_model-00001-of-00001.safetensors" + index_filename = SAFE_PEFT_WEIGHTS_INDEX_NAME + else: + weight_filename = "model-00001-of-00001.safetensors" + index_filename = SAFE_WEIGHTS_INDEX_NAME + # get index json + index_weight_file = {} + total_size = 0 + for key, weight in state_dict.items(): + index_weight_file[key] = weight_filename + total_size += weight.numel().item() * dtype_byte_size(weight.dtype) + sharded_index_json = {} + sharded_index_json["metadata"] = {"total_size": total_size} + sharded_index_json["weight_map"] = index_weight_file + if isinstance(model_to_save, LoRAModel): + sharded_index_json["type"] = "lora" + elif isinstance(model_to_save, PrefixModelForCausalLM): + sharded_index_json["type"] = "ptuning" + + os.makedirs(output_dir, exist_ok=True) + path = os.path.join(output_dir, index_filename) + with open(path, "w") as f: + json.dump(sharded_index_json, f, indent=4) + + # save checkpoint, do no support asynchronous save for single card currently. + logger.warning("Asynchronous saving is not supported for single card environment currently.") + save_file_sync(state_dict, path=os.path.join(output_dir, weight_filename)) + + save_model_config(model_to_save, output_dir) + + +def save_single_card_optimizer(model, optimizer, output_dir): + """ "Save optimizer for non-distributed environment.""" + # Split into optimizer params and master weights. 
+ optim_state_dict = nested_copy(optimizer.state_dict()) + master_weights = None + if "master_weights" in optim_state_dict.keys(): + master_weights = optim_state_dict.pop("master_weights") + if "LR_Scheduler" in optim_state_dict.keys(): + optim_state_dict.pop("LR_Scheduler") + + static2struct_name_mappings = {} + state_dict = get_expected_state_dict(model) + fp32_weight = {} + for k, v in state_dict.items(): + static2struct_name_mappings[v.name] = k + if master_weights is not None and v.dtype == core.VarDesc.VarType.FP32: + fp32_weight[k] = v + + # rename optimizer param + for key in list(optim_state_dict.keys()): + static_name, type_name = generate_base_static_name(key) + new_name = static2struct_name_mappings[static_name] + "/" + type_name + optim_state_dict[new_name] = optim_state_dict.pop(key) + if master_weights is not None: + for key in list(master_weights.keys()): + master_weights[static2struct_name_mappings[key]] = master_weights.pop(key) + master_weights.update(fp32_weight) + + # save index json + index_optimizer_file, index_master_weight_file = {}, {} + total_optim_size, total_master_weight_size = 0, 0 + for key, weight in optim_state_dict.items(): + index_optimizer_file[key] = "optimizer-00001-of-00001.safetensors" + total_optim_size += weight.numel().item() * dtype_byte_size(weight.dtype) + if master_weights is not None: + for key, weight in master_weights.items(): + index_master_weight_file[key] = "master_weights-00001-of-00001.safetensors" + total_master_weight_size += weight.numel().item() * dtype_byte_size(weight.dtype) + path = os.path.join(output_dir, SAFE_OPTIMIZER_INDEX_NAME) + master_path = os.path.join(output_dir, SAFE_MASTER_WEIGHTS_INDEX_NAME) + with open(path, "w") as f: + has_master_weights = master_weights is not None + json.dump( + { + "metadata": {"total_size": total_optim_size}, + "weight_map": index_optimizer_file, + "master_weights": has_master_weights, + }, + f, + indent=4, + ) + if master_weights is not None: + with open(master_path, "w") as f: + json.dump( + {"metadata": {"total_size": total_master_weight_size}, "weight_map": index_master_weight_file}, + f, + indent=4, + ) + + # save optimizer state dict + save_file_sync(optim_state_dict, path=os.path.join(output_dir, "optimizer-00001-of-00001.safetensors")) + if master_weights is not None: + save_file_sync(master_weights, path=os.path.join(output_dir, "master_weights-00001-of-00001.safetensors")) + + +def load_single_card_checkpoint(model, resume_from_checkpoint: str): + if isinstance(model, LoRAModel) or isinstance(model, PrefixModelForCausalLM): + index_filename = SAFE_PEFT_WEIGHTS_INDEX_NAME + else: + index_filename = SAFE_WEIGHTS_INDEX_NAME + resolved_archive_file, sharded_metadata = get_checkpoint_shard_files( + pretrained_model_name_or_path=resume_from_checkpoint, + index_filename=os.path.join(resume_from_checkpoint, index_filename), + ) + + loaded_keys = sharded_metadata["all_checkpoint_keys"] + model_state_dict = get_expected_state_dict(model) + expected_keys = set(list(model_state_dict.keys())) + missing_keys = expected_keys - set(loaded_keys) + + if len(missing_keys) > 0: + raise ValueError(f"Missing keys: {missing_keys}") + + state_dict = load_state_dict(resolved_archive_file[0], None, expected_keys) + error_msgs = _load_state_dict_into_model(model, state_dict, "") + del state_dict + gc.collect() + + if error_msgs: + raise RuntimeError(f"Error(s) in loading state dict for {model.__class__.__name__}:\n\t{error_msgs}") + + +def load_single_card_optimizer(model, optimizer, resume_from_checkpoint: 
str): + returned_optim_state_dict = nested_copy(optimizer.state_dict()) + + resolved_archive_file, sharded_metadata = get_optimizer_shard_files( + optimizer_path=resume_from_checkpoint, + index_filename=os.path.join(resume_from_checkpoint, SAFE_OPTIMIZER_INDEX_NAME), + ) + has_master_weights = True if sharded_metadata["master_weights"] else False + + model_state_dict = get_expected_state_dict(model) + struct2static_name_mappings = {k: v.name for k, v in model_state_dict.items()} + expected_keys = sharded_metadata["all_optimizer_keys"] + + if has_master_weights: + returned_optim_state_dict["master_weights"] = {} + resolved_archive_file_mw, sharded_metadata_mw = get_optimizer_shard_files( + optimizer_path=resume_from_checkpoint, + index_filename=os.path.join(resume_from_checkpoint, SAFE_MASTER_WEIGHTS_INDEX_NAME), + ) + expected_keys_mw = sharded_metadata_mw["all_optimizer_keys"] + + state_dict_optim = load_state_dict(resolved_archive_file[0], None, expected_keys) + if has_master_weights: + state_dict_optim_mw = load_state_dict(resolved_archive_file_mw[0], None, expected_keys_mw) + + for key in list(state_dict_optim.keys()): + key_name = key.split("/") + static_name = struct2static_name_mappings[key_name[0]] + if has_master_weights: + if model_state_dict[key_name[0]].dtype != core.VarDesc.VarType.FP32: + key_name = "_".join([static_name, FP32_MASTER, key_name[1]]) + else: + key_name = "_".join([static_name, key_name[1]]) + returned_optim_state_dict[key_name] = state_dict_optim.pop(key) + returned_optim_state_dict[key_name].name = key_name + if has_master_weights: + for key in list(state_dict_optim_mw.keys()): + static_name = struct2static_name_mappings[key] + returned_optim_state_dict["master_weights"][static_name] = state_dict_optim_mw.pop(key) + returned_optim_state_dict["master_weights"][static_name].name = "_".join([static_name, FP32_MASTER]) + return returned_optim_state_dict diff --git a/paddlenlp/trainer/unified_checkpoint/sharding_split_param_utils.py b/paddlenlp/trainer/unified_checkpoint/sharding_split_param_utils.py new file mode 100644 index 000000000000..f337b1a8186b --- /dev/null +++ b/paddlenlp/trainer/unified_checkpoint/sharding_split_param_utils.py @@ -0,0 +1,300 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
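The sharding_split_param_utils.py module that follows reassembles parameters that sharding stage1 V2 has flattened, padded, and sliced across ranks: each rank owns a contiguous `[begin, end)` range of the padded 1-D buffer, so merging amounts to concatenating the slices in order, dropping the alignment padding, and restoring the original shape. A minimal local sketch of that merge step, with an invented slice layout and no cross-rank communication (the real code also moves slices between ranks with `dist.stream.send`/`recv`):

# Illustrative sketch: merge per-rank 1-D slices of a padded, flattened parameter.
import numpy as np


def merge_param_slices(slices, shape, numel):
    """slices: iterable of (begin, end, 1-D chunk) tuples covering [0, padded_size)."""
    ordered = sorted(slices, key=lambda item: item[0])
    flat = np.concatenate([chunk for _, _, chunk in ordered])
    return flat[:numel].reshape(shape)  # drop the alignment padding, restore the shape


# Example: a [2, 3] parameter (6 elements) padded to 8 and split across two ranks.
rank0 = (0, 4, np.arange(0, 4, dtype="float32"))
rank1 = (4, 8, np.concatenate([np.arange(4, 6, dtype="float32"), np.zeros(2, dtype="float32")]))
merged = merge_param_slices([rank0, rank1], shape=(2, 3), numel=6)  # values 0..5 in a 2x3 array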
+"""Support Sharding Stage1 V2(split param) for Unified Checkpoint""" + +import gc +import os + +import paddle +import paddle.distributed as dist +from paddle.distributed import fleet +from tqdm.auto import tqdm + +from paddlenlp.peft import LoRAModel, PrefixModelForCausalLM +from paddlenlp.transformers.model_utils import load_state_dict +from paddlenlp.utils.env import ( + SAFE_MASTER_WEIGHTS_INDEX_NAME, + SAFE_OPTIMIZER_INDEX_NAME, +) +from paddlenlp.utils.nested import nested_copy + +from .utils import ( + FP32_MASTER, + generate_base_static_name, + get_expected_state_dict, + get_optimizer_shard_files, + mapping_optimizer_tp_actions, +) + +__all__ = ["gather_splited_param_for_optimizer", "load_unified_optimizer_split_param"] + + +def merge_splited_param( + state_dict, partial_tensor_list, param_shape_info, send_table, recv_table, is_master_weights=False +): + """Merge the splited param in sharding group.""" + global_rank = dist.get_rank() + for key in list(state_dict.keys()): + if state_dict[key].numel().item() == 1: # for example: beta1, beta2 + continue + + static_name = key if is_master_weights else generate_base_static_name(key)[0] + shape, numel, index, padded_size = param_shape_info[static_name] + if static_name not in partial_tensor_list: + state_dict[key] = state_dict[key].reshape(shape) + continue + + recv_rank = recv_table[static_name] + send_info = send_table[static_name] + + base_padding_start = index + numel + base_padding_end = index + padded_size + + if global_rank == recv_rank: + tmp_tensor_list = [] + for send_rank, begin, end in send_info: + padding_start = max(begin, base_padding_start) + padding_end = min(end, base_padding_end) + + if send_rank == recv_rank: + tensor = ( + state_dict[key] if padding_start >= padding_end else state_dict[key][: padding_start - begin] + ) + tmp_tensor_list.append(tensor) + else: + length = end - begin if padding_start >= padding_end else padding_start - begin + tmp_tensor = paddle.empty(shape=[length], dtype=state_dict[key].dtype) + dist.stream.recv(tmp_tensor, src=send_rank) + tmp_tensor_list.append(tmp_tensor) + state_dict[key] = paddle.concat(tmp_tensor_list, axis=0).reshape(shape) + else: + for send_rank, begin, end in send_info: + padding_start = max(begin, base_padding_start) + padding_end = min(end, base_padding_end) + if global_rank == send_rank: + tensor = ( + state_dict[key] if padding_start >= padding_end else state_dict[key][: padding_start - begin] + ) + dist.stream.send(tensor, dst=recv_rank) + state_dict.pop(key) + return state_dict + + +def gather_splited_param_for_optimizer(optimizer): + hcg = fleet.get_hybrid_communicate_group() + sharding_group = hcg.get_sharding_parallel_group() + global_rank = dist.get_rank() + param_slice_info = {} + param_shape_info = {} + for buffer in optimizer._inner_opt._comm_buffer_list: + for key in buffer._sharding_param_grad_view.keys(): + param_slice_info[key] = ( + buffer._sharding_param_grad_view[key]._param_begin, + buffer._sharding_param_grad_view[key]._param_end, + ) + param_shape_info[key] = ( + buffer._sharding_param_grad_view[key]._param.shape, + buffer._sharding_param_grad_view[key]._param.numel().item(), + buffer._sharding_param_grad_view[key]._index, + buffer._sharding_param_grad_view[key]._padded_size, + ) + param_slice_info["global_rank"] = global_rank + param_slice_info_list = [] + dist.all_gather_object(param_slice_info_list, param_slice_info, group=sharding_group) + + optim_state_dict = nested_copy(optimizer.state_dict()) + master_weights = None + if "master_weights" in 
optim_state_dict.keys(): + master_weights = optim_state_dict.pop("master_weights") + if "LR_Scheduler" in optim_state_dict.keys(): + optim_state_dict.pop("LR_Scheduler") + + # deal with optimizer param + partial_tensor_list = [] + for key in list(optim_state_dict.keys()): + static_name, _ = generate_base_static_name(key) + if static_name in param_slice_info.keys(): + if optim_state_dict[key].numel().item() == 1: # for example: beta1, beta2 + continue + begin, end = param_slice_info[static_name] + shape, numel, _, _ = param_shape_info[static_name] + if end - begin == numel: # full tensor + optim_state_dict[key] = optim_state_dict[key].reshape(shape) + elif end <= begin: # empty tensor + continue + else: # partial tensor, end > begin but end - begin < numel + partial_tensor_list.append(static_name) + + send_table = {} + recv_table = {} + for key in partial_tensor_list: + sharding_ranklist = [] + for slice_info in param_slice_info_list: + begin, end = slice_info[key] + if end > begin: + sharding_ranklist.append((slice_info["global_rank"], begin, end)) + recv_table[key] = sharding_ranklist[0][0] # which sharding_rank to recv the splited tensor + send_table[key] = [(rank, begin, end) for rank, begin, end in sharding_ranklist] + + merge_splited_param(optim_state_dict, partial_tensor_list, param_shape_info, send_table, recv_table, False) + if master_weights is not None: + merge_splited_param(master_weights, partial_tensor_list, param_shape_info, send_table, recv_table, True) + return optim_state_dict, master_weights + + +def load_unified_optimizer_split_param(model, optimizer, resume_from_checkpoint): + returned_optim_state_dict = nested_copy(optimizer.state_dict()) + + index_filename, index_filename_master_weights = SAFE_OPTIMIZER_INDEX_NAME, SAFE_MASTER_WEIGHTS_INDEX_NAME + + resolved_archive_file, sharded_metadata = get_optimizer_shard_files( + optimizer_path=resume_from_checkpoint, + index_filename=os.path.join(resume_from_checkpoint, index_filename), + ) + has_master_weights = True if sharded_metadata["master_weights"] else False + + typename_set = set() + for key in sharded_metadata["weight_map"].keys(): + _, typename = key.split("/") + typename_set.add(typename) + + model_state_dict = get_expected_state_dict(model) + model_keys = list(model_state_dict.keys()) + static2struct_name_mappings = {v.name: k for k, v in model_state_dict.items()} # get optimizer param mappings + struct2static_name_mappings = {k: v.name for k, v in model_state_dict.items()} + + expected_keys = [] + param_slice_info = {} + param_shape_info = {} + for buffer in optimizer._inner_opt._comm_buffer_list: + for key in buffer._sharding_param_grad_view.keys(): + begin = buffer._sharding_param_grad_view[key]._param_begin + end = buffer._sharding_param_grad_view[key]._param_end + if end > begin: + expected_keys.append(key) + shape = buffer._sharding_param_grad_view[key]._param.shape + numel = buffer._sharding_param_grad_view[key]._param.numel().item() + index = buffer._sharding_param_grad_view[key]._index + padded_size = buffer._sharding_param_grad_view[key]._padded_size + param_slice_info[key] = (begin, end) + param_shape_info[key] = (shape, numel, index, padded_size) + + expected_keys = set([static2struct_name_mappings.get(name, None) for name in expected_keys]) + expected_keys_optim = [] + for key in expected_keys: + for typename in typename_set: + expected_keys_optim.append(f"{key}/{typename}") + expected_keys_optim = set(expected_keys_optim) + + if len(resolved_archive_file) > 1: + resolved_archive_file = 
tqdm(resolved_archive_file, desc="Loading optimizer shards") + + if has_master_weights: + returned_optim_state_dict["master_weights"] = {} + resolved_archive_file_mw, sharded_metadata_mw = get_optimizer_shard_files( + optimizer_path=resume_from_checkpoint, + index_filename=os.path.join(resume_from_checkpoint, index_filename_master_weights), + ) + if len(resolved_archive_file_mw) > 1: + resolved_archive_file_mw = tqdm(resolved_archive_file_mw, desc="Loading master weights shards") + + def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected_keys, is_master_weights=False): + returned_state_dict = {} + + if model.config.tensor_parallel_degree > 1: + if isinstance(model, LoRAModel) or isinstance(model, PrefixModelForCausalLM): + tp_actions = model._get_tensor_parallel_convert_actions(model_keys, is_split=True, ignore_error=True) + else: + tp_actions = model.get_tensor_parallel_convert_actions(model.config, model_keys, ignore_error=True) + if not is_master_weights: + tp_actions = mapping_optimizer_tp_actions(tp_actions, expected_keys) + + for shard_file in resolved_archive_file: + if expected_keys.isdisjoint(sharded_metadata["file_map"][os.path.split(shard_file)[-1]]): + continue + if model.config.tensor_parallel_degree > 1: + state_dict = load_state_dict(shard_file, tp_actions, expected_keys, device="cpu") + else: + state_dict = load_state_dict(shard_file, None, expected_keys, device="cpu") + returned_state_dict.update(state_dict) + del state_dict + gc.collect() + + return returned_state_dict + + # get tp params + state_dict_optim = load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected_keys_optim) + + # need to split param for different sharding rank, maybe need to deal with oom issue. + for key in list(state_dict_optim.keys()): + key_name = key.split("/") + static_name = struct2static_name_mappings.get(key_name[0], None) + + if state_dict_optim[key].numel().item() > 1: + begin, end = param_slice_info[static_name] + shape, numel, index, padded_size = param_shape_info[static_name] + state_dict_optim[key] = state_dict_optim[key].reshape([-1]) + state_dict_optim[key] = state_dict_optim[key][begin - index : end - index] + + padding_start = max(begin, index + numel) + padding_end = min(end, index + padded_size) + if padding_start < padding_end: + state_dict_optim[key] = paddle.concat( + ( + state_dict_optim[key], + paddle.zeros([padding_end - padding_start], dtype=state_dict_optim[key].dtype), + ) + ) + if has_master_weights: + key_name = "_".join([static_name, FP32_MASTER, key_name[1]]) + else: + key_name = "_".join([static_name, key_name[1]]) + + state_dict_optim[key] = state_dict_optim[key]._copy_to(paddle.framework._current_expected_place(), False) + + returned_optim_state_dict[key_name] = state_dict_optim.pop(key) + returned_optim_state_dict[key_name].name = key_name + + if has_master_weights: + state_dict_master_weight = load_resolved_archive_file( + resolved_archive_file_mw, + sharded_metadata_mw, + expected_keys, + is_master_weights=True, + ) + + for key in list(state_dict_master_weight.keys()): + static_name = struct2static_name_mappings.get(key, None) + if state_dict_master_weight[key].numel().item() > 1: + begin, end = param_slice_info[static_name] + shape, numel, index, padded_size = param_shape_info[static_name] + state_dict_master_weight[key] = state_dict_master_weight[key].reshape([-1]) + state_dict_master_weight[key] = state_dict_master_weight[key][begin - index : end - index] + + padding_start = max(begin, index + numel) + padding_end 
= min(end, index + padded_size) + if padding_start < padding_end: + state_dict_master_weight[key] = paddle.concat( + ( + state_dict_master_weight[key], + paddle.zeros([padding_end - padding_start], dtype=state_dict_master_weight[key].dtype), + ) + ) + state_dict_master_weight[key] = state_dict_master_weight[key]._copy_to( + paddle.framework._current_expected_place(), False + ) + returned_optim_state_dict["master_weights"][static_name] = state_dict_master_weight.pop(key) + returned_optim_state_dict["master_weights"][static_name].name = "_".join([static_name, FP32_MASTER]) + + return returned_optim_state_dict diff --git a/paddlenlp/trainer/plugins/shared_memory_utils.py b/paddlenlp/trainer/unified_checkpoint/shared_memory_utils.py similarity index 100% rename from paddlenlp/trainer/plugins/shared_memory_utils.py rename to paddlenlp/trainer/unified_checkpoint/shared_memory_utils.py diff --git a/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py b/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py new file mode 100644 index 000000000000..d6c2db82f126 --- /dev/null +++ b/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py @@ -0,0 +1,651 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
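The UnifiedCheckpointHandler defined in unified_checkpoint.py below, like the loaders above, relies on one optimizer-key convention: the checkpoint index stores keys of the form "structure_name/typename", which are mapped back to Paddle's static optimizer state names, inserting `fp32_master_0` when master weights exist and the parameter itself is not FP32. A minimal sketch of that mapping, using a hypothetical name mapping and dtype (the real code derives both from the model state dict):

# Illustrative sketch of the "struct_name/typename" -> static optimizer name mapping.
FP32_MASTER = "fp32_master_0"


def to_static_optim_key(struct_key, struct2static, param_dtype, has_master_weights):
    struct_name, typename = struct_key.split("/")
    static_name = struct2static[struct_name]
    if has_master_weights and param_dtype != "float32":
        return "_".join([static_name, FP32_MASTER, typename])
    return "_".join([static_name, typename])


# Example (hypothetical names):
# to_static_optim_key("llama.embed_tokens.weight/moment1_0",
#                     {"llama.embed_tokens.weight": "embedding_0.w_0"},
#                     param_dtype="bfloat16", has_master_weights=True)
# -> "embedding_0.w_0_fp32_master_0_moment1_0"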
+ +import copy +import json +import os +import sys + +import paddle +from paddle.distributed import fleet + +try: + from paddle.base import core +except: + core = None + +from paddlenlp.peft import LoRAModel, PrefixModelForCausalLM +from paddlenlp.trainer.argparser import strtobool +from paddlenlp.trainer.utils.helper import distributed_isfile +from paddlenlp.transformers.model_utils import ( + PretrainedModel, + _add_variant, + unwrap_model, +) +from paddlenlp.transformers.utils import ( + device_guard, + dtype_byte_size, + is_safetensors_available, +) +from paddlenlp.utils.env import ( + LORA_WEIGHTS_NAME, + PADDLE_MASTER_WEIGHTS_NAME, + PADDLE_OPTIMIZER_NAME, + PADDLE_WEIGHTS_NAME, + PREFIX_WEIGHTS_NAME, + SAFE_MASTER_WEIGHTS_INDEX_NAME, + SAFE_MASTER_WEIGHTS_NAME, + SAFE_OPTIMIZER_INDEX_NAME, + SAFE_OPTIMIZER_NAME, + SAFE_PEFT_WEIGHTS_INDEX_NAME, + SAFE_PEFT_WEIGHTS_NAME, + SAFE_WEIGHTS_INDEX_NAME, + SAFE_WEIGHTS_NAME, +) +from paddlenlp.utils.log import logger +from paddlenlp.utils.nested import nested_copy + +if is_safetensors_available(): + if sys.platform.startswith("win"): + from safetensors.numpy import load_file + else: + from paddlenlp.utils.safetensors import fast_load_file as load_file + +from .async_handler import AsyncCheckpointHandler +from .check_completion import check_unified_checkpoint, check_unified_optimizer +from .load_dynamic import ( + load_unified_checkpoint_dynamically, + load_unified_optimizer_dynamically, +) +from .load_local import load_unified_checkpoint_locally, load_unified_optimizer_locally +from .load_save_single_card import ( + load_single_card_checkpoint, + load_single_card_optimizer, + save_single_card_checkpoint, + save_single_card_optimizer, +) +from .sharding_split_param_utils import gather_splited_param_for_optimizer +from .utils import ( + FP32_MASTER, + UnifiedCheckpointOption, + filter_params, + gather_sharded_object, + generate_base_static_name, + get_expected_state_dict, + get_sharded_file_name, + get_sharded_index, + is_need_master_weight, + is_sharding_split_param_mode, + merge_tensor_parallel_for_optimizer, + merge_tensor_parallel_with_shard, + reduce_master_weights_status, + rename_shard_file, + save_model_config, +) + +__all__ = ["UnifiedCheckpointHandler"] + + +class UnifiedCheckpointHandler: + def __init__(self, args): + self.args = args + self.async_handler = AsyncCheckpointHandler(args) + + def save_unified_checkpoint(self, model, optimizer, output_dir, signal_dir=None): + """save unified checkpoint + + Args: + model (PretrainedModel): model to save + output_dir (str): save dir + safe_serialization (bool, optional): use safetensors. Defaults to False. + + Raises: + ValueError: if model is not an instance of `PretrainedModel` and the model cannot be saved + """ + if isinstance(model, PretrainedModel): + model_to_save = model + elif isinstance(unwrap_model(model), PretrainedModel): + model_to_save = unwrap_model(model) + elif isinstance(model, PrefixModelForCausalLM) or isinstance(model, LoRAModel): + model_to_save = model + else: + raise ValueError("Unified checkpoint only supports PretrainedModel, LoRAModel and PrefixModelForCausalLM!") + + # Under non distributed environment. 
+ if paddle.distributed.get_world_size() <= 1: + save_single_card_checkpoint(model_to_save, output_dir) + return + + skip_save_model_weight = False + if UnifiedCheckpointOption.SKIP_SAVE_MODEL_WEIGHT.value in self.args.unified_checkpoint_config: + if is_need_master_weight(optimizer, is_fp16_or_bp16=(self.args.fp16 or self.args.bf16)): + logger.info( + f"With {UnifiedCheckpointOption.SKIP_SAVE_MODEL_WEIGHT.value}, skip the model checkpoint save." + " The master weight will be loaded as model weights for next resumption." + ) + # not save model weight, load from master weight + skip_save_model_weight = True + + save_directory = output_dir + os.makedirs(save_directory, exist_ok=True) + if signal_dir is not None: + os.makedirs(signal_dir, exist_ok=True) # only for async save + + # save model weights + if not skip_save_model_weight: + state_dict, shard_file, sharded_index = unified_checkpoint_into_shards( + self.args, model_to_save, safe_serialization=True + ) + is_sync_save = True + if "async_save" in self.args.unified_checkpoint_config: + is_sync_save = False + self.async_handler._file_save_async_or_sync( + state_dict, + path=os.path.join(save_directory, shard_file), + signal_path=signal_dir, + is_sync=is_sync_save, + state_dict_type="model_weight", + ) + if sharded_index is not None: + if isinstance(model_to_save, LoRAModel) or isinstance(model_to_save, PrefixModelForCausalLM): + index_name = SAFE_PEFT_WEIGHTS_INDEX_NAME + else: + index_name = SAFE_WEIGHTS_INDEX_NAME + path = os.path.join(output_dir, index_name) + + if self.args.should_save: + with open(path, "w") as f: + json.dump(sharded_index, f, indent=4) + + if self.args.should_save: + save_model_config(model_to_save, save_directory) + + paddle.device.cuda.empty_cache() + + if strtobool(os.getenv("FLAG_LLM_PDC", "False")) and self.args.should_save: + world_size = paddle.distributed.get_world_size() + save_info = { + "world_size": world_size, + "ignore_save_lr_and_optim": self.args.ignore_save_lr_and_optim, + "skip_save_model_weight": "skip_save_model_weight" in self.args.unified_checkpoint_config, + } + paddle.save(save_info, os.path.join(save_directory, ".saving_info")) + + def load_unified_checkpoint(self, model, resume_from_checkpoint: str): + """Load potential model checkpoint + + Args: + model (PretrainedModel): Your model to load + resume_from_checkpoint (str): path of the checkpoint to load + + Returns: + None + """ + if paddle.distributed.get_world_size() <= 1: + load_single_card_checkpoint(model, resume_from_checkpoint) + return + + local_resume = check_unified_checkpoint(self.args, model, resume_from_checkpoint, safe_serialization=True) + + if not local_resume: + logger.info("Begin to dynamically load unified checkpoint!") + load_unified_checkpoint_dynamically(self.args, model, resume_from_checkpoint, safe_serialization=True) + return + + if self.args.dataset_rank == 0 or self.args.use_expert_parallel: + load_unified_checkpoint_locally(self.args, model, resume_from_checkpoint, safe_serialization=True) + + def save_non_merge_optimizer(self, model, optim_state_dict, master_weights, output_dir, signal_dir): + paddle.device.cuda.empty_cache() + + # gather global master_weights status. 
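+        # If any rank across the tp/pp/sharding groups holds master weights, every rank keeps a (possibly empty) dict so the save below stays collective.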
+ global_master_weights = reduce_master_weights_status(master_weights is not None) + if master_weights is None and global_master_weights: + master_weights = {} + + # get optimizer param mappings + static2struct_name_mappings = {} + state_dict = get_expected_state_dict(model) + for k, v in state_dict.items(): + static2struct_name_mappings[v.name] = k + + # rename optimizer param name + for key in list(optim_state_dict.keys()): + static_name, type_name = generate_base_static_name(key) + new_name = static2struct_name_mappings[static_name] + "/" + type_name + optim_state_dict[new_name] = optim_state_dict.pop(key) + if master_weights is not None: + for key in list(master_weights.keys()): + master_weights[static2struct_name_mappings[key]] = master_weights.pop(key) + + no_sync_kname = [] + model_state_dict = get_expected_state_dict(model) + for k, v in model_state_dict.items(): + if getattr(v, "no_sync", False): + no_sync_kname.append(k) + + hcg = fleet.get_hybrid_communicate_group() + dp_group = hcg.get_data_parallel_group() + dp_rank = dp_group.rank if dp_group.nranks > 1 else 0 + if self.args.use_expert_parallel: + for k in list(optim_state_dict.keys()): + model_k = k.split("/")[0] + if dp_rank > 0 and model_k not in no_sync_kname: + optim_state_dict.pop(k) + if master_weights is not None: + for k in list(master_weights.keys()): + model_k = k.split("/")[0] + if dp_rank > 0 and model_k not in no_sync_kname: + master_weights.pop(k) + + optimizer_name = _add_variant(SAFE_OPTIMIZER_NAME, self.args.optimizer_name_suffix) + master_weights_name = _add_variant(SAFE_MASTER_WEIGHTS_NAME, self.args.optimizer_name_suffix) + + is_sync_save = True + if "async_save" in self.args.unified_checkpoint_config: + is_sync_save = False + self.async_handler._file_save_async_or_sync( + optim_state_dict, + path=os.path.join(output_dir, optimizer_name), + signal_path=signal_dir, + is_sync=is_sync_save, + state_dict_type="optimizer_weight", + ) + self.async_handler._file_save_async_or_sync( + master_weights, + path=os.path.join(output_dir, master_weights_name), + signal_path=signal_dir, + is_sync=is_sync_save, + state_dict_type="master_weight", + ) + + def load_non_merge_optimizer(self, model, optimizer, resume_from_checkpoint): + # init and get optimizer LR_Scheduler + returned_optim_state_dict = nested_copy(optimizer.state_dict()) + + optimizer_name = _add_variant(SAFE_OPTIMIZER_NAME, self.args.optimizer_name_suffix) + master_weights_name = _add_variant(SAFE_MASTER_WEIGHTS_NAME, self.args.optimizer_name_suffix) + optimizer_path = os.path.join(resume_from_checkpoint, optimizer_name) + master_weights_path = os.path.join(resume_from_checkpoint, master_weights_name) + has_master_weights = True if os.path.isfile(master_weights_path) else False + + model_state_dict = get_expected_state_dict(model) + struct2static_name_mappings = {k: v.name for k, v in model_state_dict.items()} # get optimizer param mappings + optimizer_state_dict = load_file(optimizer_path) + if has_master_weights: + master_weights = load_file(master_weights_path) + + # rename and move to paddle.Tensor + for key in list(optimizer_state_dict.keys()): + key_name = key.split("/") + static_name = struct2static_name_mappings[key_name[0]] + if has_master_weights: + if model_state_dict[key_name[0]].dtype != core.VarDesc.VarType.FP32: + key_name = "_".join([static_name, FP32_MASTER, key_name[1]]) + else: + key_name = "_".join([static_name, key_name[1]]) + else: + key_name = "_".join([static_name, key_name[1]]) + with device_guard(): + weight = 
paddle.Tensor(optimizer_state_dict.pop(key), zero_copy=True) + weight = weight._copy_to(paddle.framework._current_expected_place(), False) + returned_optim_state_dict[key_name] = weight + returned_optim_state_dict[key_name].name = key_name + + if has_master_weights: + returned_optim_state_dict["master_weights"] = {} + for key in list(master_weights.keys()): + static_name = struct2static_name_mappings[key] + with device_guard(): + weight = paddle.Tensor(master_weights.pop(key), zero_copy=True) + weight = weight._copy_to(paddle.framework._current_expected_place(), False) + returned_optim_state_dict["master_weights"][static_name] = weight + returned_optim_state_dict["master_weights"][static_name].name = "_".join([static_name, FP32_MASTER]) + + return returned_optim_state_dict + + def save_unified_optimizer(self, model, optimizer, output_dir, signal_dir): + """save unified optimizer + + Args: + model (PretrainedModel): model used to get key mapping. + optimizer (Optimizer): optimizer to save + output_dir (str): Save directory. + signal_dir (str): Asynchronous saving signal directory. + + """ + + if paddle.distributed.get_world_size() <= 1: + save_single_card_optimizer(model, optimizer, output_dir) # no need to save signal + return + + if is_sharding_split_param_mode(self.args): + optim_state_dict, master_weights = gather_splited_param_for_optimizer(optimizer) + else: + optim_state_dict = nested_copy(optimizer.state_dict()) + master_weights = None + if "master_weights" in optim_state_dict.keys(): + master_weights = optim_state_dict["master_weights"] + optim_state_dict.pop("master_weights") + if "LR_Scheduler" in optim_state_dict.keys(): + optim_state_dict.pop("LR_Scheduler") + + if "ignore_merge_optimizer" in self.args.unified_checkpoint_config: + self.save_non_merge_optimizer(model, optim_state_dict, master_weights, output_dir, signal_dir) + return + + # Split into naive optimizer params and master weights. 
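+        # unified_optimizer_into_shards returns one (state_dict, shard_file, sharded_index) tuple, plus a second one for master weights when they exist.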
+        results = unified_optimizer_into_shards(
+            self.args, model, optim_state_dict, master_weights, safe_serialization=True
+        )
+        master_weight_state_dict = None
+        if len(results) == 1:
+            optim_state_dict, shard_optim_file, sharded_optim_index = results[0]
+        else:
+            optim_state_dict, shard_optim_file, sharded_optim_index = results[0]
+            master_weight_state_dict, shard_master_weight_file, sharded_master_weight_index = results[1]
+
+        paddle.device.cuda.empty_cache()
+        save_directory = output_dir
+        os.makedirs(save_directory, exist_ok=True)
+        if signal_dir is not None:
+            os.makedirs(signal_dir, exist_ok=True)
+
+        is_sync_save = True
+        if "async_save" in self.args.unified_checkpoint_config:
+            is_sync_save = False
+        self.async_handler._file_save_async_or_sync(
+            optim_state_dict,
+            path=os.path.join(save_directory, shard_optim_file),
+            signal_path=signal_dir,
+            is_sync=is_sync_save,
+            state_dict_type="optimizer_weight",
+        )
+        if master_weight_state_dict is not None:
+            self.async_handler._file_save_async_or_sync(
+                master_weight_state_dict,
+                path=os.path.join(save_directory, shard_master_weight_file),
+                signal_path=signal_dir,
+                is_sync=is_sync_save,
+                state_dict_type="master_weight",
+            )
+
+        if sharded_optim_index is not None:
+            optimizer_index_name = SAFE_OPTIMIZER_INDEX_NAME
+            path = os.path.join(output_dir, optimizer_index_name)
+            if self.args.should_save:
+                with open(path, "w") as f:
+                    json.dump(sharded_optim_index, f, indent=4)
+
+            master_weights_name = SAFE_MASTER_WEIGHTS_INDEX_NAME
+            if UnifiedCheckpointOption.SKIP_SAVE_MODEL_WEIGHT.value in self.args.unified_checkpoint_config:
+                master_weights_name = SAFE_WEIGHTS_INDEX_NAME
+            master_path = os.path.join(output_dir, master_weights_name)
+            if master_weight_state_dict is not None:
+                if self.args.should_save:
+                    with open(master_path, "w") as f:
+                        json.dump(sharded_master_weight_index, f, indent=4)
+
+    def load_unified_optimizer(self, model, optimizer, resume_from_checkpoint):
+        """Load the unified optimizer checkpoint.
+
+        Args:
+            model (PretrainedModel): model used to map optimizer parameters to structured names.
+            optimizer (Optimizer): optimizer whose state is to be restored.
+            resume_from_checkpoint (str): path of the checkpoint to load.
+
+        Returns:
+            dict or None: the restored optimizer state dict, or None on ranks that do not load it.
+        """
+
+        if paddle.distributed.get_world_size() <= 1:
+            optim_state_dict = load_single_card_optimizer(model, optimizer, resume_from_checkpoint)
+            return optim_state_dict
+
+        has_merge_optimizer_safetensors = distributed_isfile(
+            os.path.join(resume_from_checkpoint, SAFE_OPTIMIZER_INDEX_NAME)
+        )
+        # If there is no merged optimizer checkpoint, load the non-merged (per-rank) optimizer files.
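+        # The non-merge path reads the per-rank optimizer and master-weight safetensors files written by save_non_merge_optimizer.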
+        if not has_merge_optimizer_safetensors:
+            if self.args.data_parallel_rank == 0 or self.args.use_expert_parallel:
+                returned_optim_state_dict = self.load_non_merge_optimizer(
+                    model,
+                    optimizer,
+                    resume_from_checkpoint,
+                )
+                return returned_optim_state_dict
+            else:
+                return None
+
+        local_resume = check_unified_optimizer(
+            self.args, model, optimizer, resume_from_checkpoint, safe_serialization=True
+        )
+        if not local_resume:
+            logger.info("Begin to dynamically load unified optimizer!")
+            returned_optim_state_dict = load_unified_optimizer_dynamically(
+                self.args, model, optimizer, resume_from_checkpoint, safe_serialization=True
+            )
+            return returned_optim_state_dict
+
+        if self.args.data_parallel_rank == 0 or self.args.use_expert_parallel:
+            returned_optim_state_dict = load_unified_optimizer_locally(
+                self.args, model, optimizer, resume_from_checkpoint, safe_serialization=True
+            )
+            return returned_optim_state_dict
+        return None
+
+    def unlink_shared_memory(self):
+        return self.async_handler.unlink_shared_memory()
+
+
+def unified_checkpoint_into_shards(
+    args,
+    model_to_save,
+    safe_serialization=False,
+):
+    """Get the state_dict to save, together with its shard file name and sharded index.
+
+    Args:
+        model_to_save (nn.Layer): model to save.
+        safe_serialization (bool, optional): safe serialization using safetensors. Defaults to False.
+
+    Returns:
+        tuple: (state_dict, shard_file, sharded_index), where shard_file is the shard file name and
+            sharded_index maps each weight name to its file.
+    """
+    paddle.device.cuda.empty_cache()
+    assert hasattr(model_to_save, "config")
+
+    state_dict = get_expected_state_dict(model_to_save)
+    all_filter_keys = filter_params(model_to_save, state_dict)
+
+    config_to_save = copy.deepcopy(model_to_save.config)
+
+    if config_to_save.tensor_parallel_degree > 1:
+        if isinstance(model_to_save, LoRAModel) or isinstance(model_to_save, PrefixModelForCausalLM):
+            tp_actions = model_to_save._get_tensor_parallel_convert_actions(
+                all_filter_keys, is_split=False, ignore_error=True
+            )
+        else:
+            tp_actions = model_to_save.get_tensor_parallel_convert_actions(
+                model_to_save.config, state_dict.keys(), is_split=False, ignore_error=True
+            )
+        logger.info("Unified model tensor parallel weights in shards")
+        state_dict = merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys)
+
+    # build index json file
+    index_weight_file = {}
+    total_size = 0
+    if isinstance(model_to_save, LoRAModel):
+        weights_name = SAFE_PEFT_WEIGHTS_NAME if safe_serialization else LORA_WEIGHTS_NAME
+    elif isinstance(model_to_save, PrefixModelForCausalLM):
+        weights_name = SAFE_PEFT_WEIGHTS_NAME if safe_serialization else PREFIX_WEIGHTS_NAME
+    else:
+        weights_name = SAFE_WEIGHTS_NAME if safe_serialization else PADDLE_WEIGHTS_NAME
+
+    shard_file = get_sharded_file_name(args, weights_name)
+    # renumber the shard file name for expert_parallel.
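+    # rename_shard_file gathers shard file names across the tp/pp/dp groups and assigns each file a globally unique index.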
+ if args.use_expert_parallel: + shard_file = rename_shard_file(args, shard_file, weights_name) + + for key, weight in state_dict.items(): + index_weight_file[key] = shard_file + total_size += weight.numel().item() * dtype_byte_size(weight.dtype) + + index_file_list, total_size_list = gather_sharded_object( + index_weight_file, total_size, use_expert_parallel=args.use_expert_parallel + ) + sharded_index = get_sharded_index( + index_file_list, + total_size_list, + ) + if sharded_index is not None: + if isinstance(model_to_save, LoRAModel): + sharded_index["type"] = "lora" + elif isinstance(model_to_save, PrefixModelForCausalLM): + sharded_index["type"] = "ptuning" + + paddle.device.cuda.empty_cache() + + return state_dict, shard_file, sharded_index + + +def unified_optimizer_into_shards( + args, + model, + optim_state_dict, + master_weights, + safe_serialization=False, +): + """Get optimizer state dict and master weight state dict. + + Args: + optimizer (Optimizer): optimizer to save. + safe_serialization (bool, optional): safe serialization using safetensors. Defaults to False. + """ + paddle.device.cuda.empty_cache() + + # gather global master_weights status. + global_master_weights = reduce_master_weights_status(master_weights is not None) + if master_weights is None and global_master_weights: + master_weights = {} + + # get optimizer param mappings + static2struct_name_mappings = {} + state_dict = get_expected_state_dict(model) + fp32_weight = {} + for k, v in state_dict.items(): + static2struct_name_mappings[v.name] = k + if master_weights is not None and v.dtype == core.VarDesc.VarType.FP32: + if args.dataset_rank > 0: # deal with different dataset rank. + continue + fp32_weight[k] = v + + # rename optimizer param + for key in list(optim_state_dict.keys()): + static_name, type_name = generate_base_static_name(key) + new_name = static2struct_name_mappings[static_name] + "/" + type_name + optim_state_dict[new_name] = optim_state_dict.pop(key) + if master_weights is not None: + for key in list(master_weights.keys()): + master_weights[static2struct_name_mappings[key]] = master_weights.pop(key) + master_weights.update(fp32_weight) + + # filter optimizer param + if master_weights is not None: + filter_master_keys = filter_params(model, master_weights, is_optimizer=True) + filter_optim_keys = filter_params(model, optim_state_dict, is_optimizer=True) + + tp_group = fleet.get_hybrid_communicate_group().get_model_parallel_group() + tp_size = tp_group.nranks + + if tp_size > 1: + # get tp_actions + model_keys = [] + for key in optim_state_dict.keys(): + base_model_key = key.split("/")[0] + if base_model_key not in model_keys: + model_keys.append(base_model_key) + if isinstance(model, LoRAModel) or isinstance(model, PrefixModelForCausalLM): + tp_actions = model._get_tensor_parallel_convert_actions(model_keys, is_split=False, ignore_error=True) + else: + tp_actions = model.get_tensor_parallel_convert_actions( + model.config, model_keys, is_split=False, ignore_error=True + ) + logger.info("Unified optimizer tensor parallel in shards") + optim_state_dict = merge_tensor_parallel_for_optimizer( + optim_state_dict, + tp_actions, + filter_optim_keys, + state_dict if args.use_expert_parallel else None, + ) + paddle.device.cuda.empty_cache() + + if master_weights is not None: + logger.info("Unified master weight tensor parallel in shards") + master_weights = merge_tensor_parallel_for_optimizer( + master_weights, + tp_actions, + filter_master_keys, + state_dict if args.use_expert_parallel else None, + ) + 
paddle.device.cuda.empty_cache() + + # build index json file + index_optimizer_file, index_master_weight_file = {}, {} + total_optim_size, total_master_weight_size = 0, 0 + optimizer_name = SAFE_OPTIMIZER_NAME if safe_serialization else PADDLE_OPTIMIZER_NAME + master_weights_name = SAFE_MASTER_WEIGHTS_NAME if safe_serialization else PADDLE_MASTER_WEIGHTS_NAME + if UnifiedCheckpointOption.SKIP_SAVE_MODEL_WEIGHT.value in args.unified_checkpoint_config: + master_weights_name = SAFE_WEIGHTS_NAME if safe_serialization else PADDLE_WEIGHTS_NAME + shard_optimizer_file = get_sharded_file_name(args, optimizer_name, is_optimizer=True) + shard_master_weight_file = get_sharded_file_name(args, master_weights_name, is_optimizer=True) + + for key, weight in optim_state_dict.items(): + index_optimizer_file[key] = shard_optimizer_file + total_optim_size += weight.numel().item() * dtype_byte_size(weight.dtype) + + if master_weights is not None: + for key, weight in master_weights.items(): + index_master_weight_file[key] = shard_master_weight_file + total_master_weight_size += weight.numel().item() * dtype_byte_size(weight.dtype) + + index_optimizer_filelist, total_optim_size_list = gather_sharded_object( + index_optimizer_file, + total_optim_size, + is_optimizer=True, + use_expert_parallel=args.use_expert_parallel, + ) + sharded_optim_index = get_sharded_index(index_optimizer_filelist, total_optim_size_list) + if master_weights is not None: + index_master_weight_filelist, total_master_weight_size_list = gather_sharded_object( + index_master_weight_file, + total_master_weight_size, + is_optimizer=True, + use_expert_parallel=args.use_expert_parallel, + ) + sharded_master_weight_index = get_sharded_index(index_master_weight_filelist, total_master_weight_size_list) + + if sharded_optim_index is not None: + if master_weights is not None: + sharded_optim_index["master_weights"] = True + else: + sharded_optim_index["master_weights"] = False + + paddle.device.cuda.empty_cache() + if master_weights is None: + return [(optim_state_dict, shard_optimizer_file, sharded_optim_index)] + else: + return [ + (optim_state_dict, shard_optimizer_file, sharded_optim_index), + (master_weights, shard_master_weight_file, sharded_master_weight_index), + ] diff --git a/paddlenlp/trainer/unified_checkpoint/utils.py b/paddlenlp/trainer/unified_checkpoint/utils.py new file mode 100644 index 000000000000..bad8dabbafa2 --- /dev/null +++ b/paddlenlp/trainer/unified_checkpoint/utils.py @@ -0,0 +1,748 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Unified Checkpoint Utility Functions""" + +import copy +import os + +import numpy as np +import paddle +import paddle.distributed as dist +from paddle.distributed import fleet + +try: + from paddle.base import core +except: + core = None + +from paddlenlp.peft import LoRAModel, PrefixModelForCausalLM +from paddlenlp.trainer.trainer_utils import ExplicitEnum, ShardingOption +from paddlenlp.trainer.utils.helper import distributed_isfile +from paddlenlp.transformers.model_utils import PretrainedModel, get_parameter_dtype +from paddlenlp.transformers.utils import dtype_byte_size +from paddlenlp.utils.distributed import distributed_allgather, distributed_gather +from paddlenlp.utils.env import ( + PADDLE_MASTER_WEIGHTS_INDEX_NAME, + PADDLE_PEFT_WEIGHTS_INDEX_NAME, + PADDLE_WEIGHTS_INDEX_NAME, + PAST_KEY_VALUES_FILE_NAME, + SAFE_MASTER_WEIGHTS_INDEX_NAME, + SAFE_PEFT_WEIGHTS_INDEX_NAME, + SAFE_WEIGHTS_INDEX_NAME, +) +from paddlenlp.utils.log import logger +from paddlenlp.utils.nested import flatten_list +from paddlenlp.utils.tools import get_env_device + +FP32_MASTER = "fp32_master_0" +optimizer_scalar_name = [ + "beta1_pow_acc_0", + "beta2_pow_acc_0", +] +optimizer_non_scaler_name = [ + "moment1_0", + "moment2_0", + "velocity_0", +] # to be added + + +DEST_PLACE = paddle.CPUPlace() +if paddle.device.is_compiled_with_cuda(): + DEST_PLACE = paddle.CUDAPinnedPlace() + + +class UnifiedCheckpointOption(ExplicitEnum): + """ + "- skip_save_model_weight: do not save model weights when the masters weight exist\n" + "- master_weight_compatible: 1. if the master weights exist, only load when needed\n" + " 2. if master weights does not exist, convert model weights to master weights when needed\n" + "- async_save: enable asynchronous saving checkpoints to disk\n" + "- enable_all_options: enable all optimization configurations\n" + """ + + SKIP_SAVE_MODEL_WEIGHT = "skip_save_model_weight" + MASTER_WEIGHT_COMPATIBLE = "master_weight_compatible" + ASYNC_SAVE = "async_save" + IGNORE_MERGE_OPTIMIZER = "ignore_merge_optimizer" + + +def unwrap_optimizer(optimizer): + while hasattr(optimizer, "_inner_opt") or hasattr(optimizer, "_optim"): + if hasattr(optimizer, "_inner_opt"): + optimizer = optimizer._inner_opt + if hasattr(optimizer, "_optim"): + optimizer = optimizer._optim + return optimizer + + +def is_need_master_weight(optimizer, is_fp16_or_bp16): + optimizer = unwrap_optimizer(optimizer) + if hasattr(optimizer, "_multi_precision"): + return optimizer._multi_precision and is_fp16_or_bp16 + else: + return False + + +def update_master_weight_status(args, optimizer, has_master_weight, safe_serialization): + if is_need_master_weight(optimizer, is_fp16_or_bp16=(args.fp16 or args.bf16)): + if not has_master_weight: + if UnifiedCheckpointOption.MASTER_WEIGHT_COMPATIBLE.value in args.unified_checkpoint_config: + index_filename_master_weights = ( + PADDLE_WEIGHTS_INDEX_NAME if not safe_serialization else SAFE_WEIGHTS_INDEX_NAME + ) + has_master_weight = True + logger.warning( + "The unified checkpoint does not contain master weight, " + "the model weight will be loaded as master weight." 
+                )
+            else:
+                raise ValueError(
+                    "Can't find a valid unified master weight checkpoint, "
+                    f"add '{UnifiedCheckpointOption.MASTER_WEIGHT_COMPATIBLE.value}' into 'unified_checkpoint_config' to "
+                    "load model checkpoint as master weight"
+                )
+        else:
+            has_master_weight = True
+            index_filename_master_weights = (
+                PADDLE_MASTER_WEIGHTS_INDEX_NAME if not safe_serialization else SAFE_MASTER_WEIGHTS_INDEX_NAME
+            )
+            if UnifiedCheckpointOption.SKIP_SAVE_MODEL_WEIGHT.value in args.unified_checkpoint_config:
+                index_filename_master_weights = (
+                    PADDLE_WEIGHTS_INDEX_NAME if not safe_serialization else SAFE_WEIGHTS_INDEX_NAME
+                )
+    else:
+        has_master_weight = False
+        index_filename_master_weights = None
+
+    return has_master_weight, index_filename_master_weights
+
+
+def reduce_master_weights_status(has_master_weights=False):
+    """
+    Get master_weight status through the tp, pp and sharding groups.
+    """
+    data = paddle.to_tensor([has_master_weights], dtype="int32")
+
+    hcg = fleet.get_hybrid_communicate_group()
+    tp_group = hcg.get_model_parallel_group()
+    pp_group = hcg.get_pipe_parallel_group()
+    sharding_group = hcg.get_sharding_parallel_group()
+
+    if tp_group.nranks > 1:
+        dist.all_reduce(data, op=dist.ReduceOp.SUM, group=tp_group)
+    if pp_group.nranks > 1:
+        dist.all_reduce(data, op=dist.ReduceOp.SUM, group=pp_group)
+    if sharding_group.nranks > 1:
+        dist.all_reduce(data, op=dist.ReduceOp.SUM, group=sharding_group)
+
+    return data.item() > 0
+
+
+def select_model_weight_index(model, resume_from_checkpoint, safe_serialization, local=True):
+    """
+    Try to select the model weight index from the model weight or master weight index.
+    """
+
+    # find model weight index file
+    if isinstance(model, LoRAModel) or isinstance(model, PrefixModelForCausalLM):
+        index_filename = SAFE_PEFT_WEIGHTS_INDEX_NAME if safe_serialization else PADDLE_PEFT_WEIGHTS_INDEX_NAME
+    else:
+        index_filename = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else PADDLE_WEIGHTS_INDEX_NAME
+
+    index_filename_path = os.path.join(resume_from_checkpoint, index_filename)
+    identify_func = os.path.isfile if local else distributed_isfile
+
+    if identify_func(index_filename_path):
+        return index_filename
+    else:
+        index_filename = PADDLE_MASTER_WEIGHTS_INDEX_NAME if not safe_serialization else SAFE_MASTER_WEIGHTS_INDEX_NAME
+        index_filename_path = os.path.join(resume_from_checkpoint, index_filename)
+
+        if identify_func(index_filename_path):
+            return index_filename
+        else:
+            raise ValueError("Can't find a valid unified model or master weight checkpoint to load.")
+
+
+def mapping_optimizer_tp_actions(tp_actions, optimizer_loaded_keys):
+    """Convert param.name to param.key/moment1_0,
+    param.key/beta1_XXX or param.key/beta2_XXX.
+
+    Args:
+        tp_actions (dict): dictionary of tensor parallel actions {key: action}
+        optimizer_loaded_keys (list or set): [param.key1/moment1_0, param.key2/beta1_XXX, param.key3/beta2_XXX]
+    Returns:
+        dict: new dictionary of tensor parallel actions {key: action}
+    """
+    new_actions = {}
+    for key in optimizer_loaded_keys:
+        key_base, typename = key.split("/")
+        if typename in optimizer_non_scaler_name and key_base in tp_actions:
+            new_actions[key] = tp_actions[key_base]
+    return new_actions
+
+
+def get_expected_state_dict(model_to_save):
+    """
+    Get trainable state_dict of model_to_save.
+ """ + if isinstance(model_to_save, PretrainedModel): + state_dict = model_to_save.state_dict() + if ( + hasattr(model_to_save.config, "tie_word_embeddings") + and model_to_save.config.tie_word_embeddings + and hasattr(model_to_save, "_tied_weights_keys") + and model_to_save._tied_weights_keys is not None + ): + for key in model_to_save._tied_weights_keys: + if key in state_dict: + state_dict.pop(key) + elif isinstance(model_to_save, LoRAModel): + state_dict = model_to_save.get_trainable_state_dict() + elif isinstance(model_to_save, PrefixModelForCausalLM): + state_dict = model_to_save.prefix_encoder.state_dict() + + return state_dict + + +def get_expected_keys(args, sharded_metadata, model, optimizer, is_master_weights=False): + hcg = fleet.get_hybrid_communicate_group() + sharding_group = hcg.get_sharding_parallel_group() + sharding_rank = sharding_group.rank + in_sharding_parallel_model = sharding_group.nranks > 1 + if in_sharding_parallel_model: + params2rank = optimizer._param2rank + + model_state_dict = get_expected_state_dict(model) + struct2static_name_mappings = {k: v.name for k, v in get_expected_state_dict(model).items()} + + expected_keys = [] + for key in list(sharded_metadata["all_optimizer_keys"]): + key_name = key.split("/")[0] + if ( + is_master_weights + and key_name in model_state_dict + and model_state_dict[key_name].dtype == core.VarDesc.VarType.FP32 + ): + continue + + if args.use_expert_parallel and args.data_parallel_rank > 0: + if key_name in model_state_dict and not getattr(model_state_dict[key_name], "no_sync", False): + continue + + static_name = struct2static_name_mappings.get(key_name, None) + + if in_sharding_parallel_model: + params_rank = params2rank.get(static_name, None) + if params_rank == sharding_rank: + expected_keys.append(key) + else: + if static_name is not None: + expected_keys.append(key) + expected_keys = set(expected_keys) + + loaded_keys = sharded_metadata["all_optimizer_keys"] + missing_keys = expected_keys - set(loaded_keys) + if len(missing_keys) > 0: + raise ValueError(f"optimizer missing weights keys: {missing_keys}") + + return expected_keys + + +def get_optimizer_shard_files(optimizer_path, index_filename): + """ + For a given model: + - download and cache all the shards of a sharded checkpoint if `pretrained_model_name_or_path` is a model ID on the + Hub + - returns the list of paths to all the shards, as well as some metadata. + For the description of each arg, see [`PretrainedModel.from_pretrained`]. `index_filename` is the full path to the + index (downloaded and cached if `pretrained_model_name_or_path` is a model ID on the Hub). + """ + + import json + + if not os.path.isfile(index_filename): + raise ValueError(f"Can't find a optimizer index ({index_filename}) in {optimizer_path}.") + + with open(index_filename, "r") as f: + index = json.loads(f.read()) + + shard_filenames = sorted(set(index["weight_map"].values())) + sharded_metadata = index["metadata"] + sharded_metadata["all_optimizer_keys"] = list(index["weight_map"].keys()) + sharded_metadata["weight_map"] = index["weight_map"].copy() + sharded_metadata["master_weights"] = index.get("master_weights", False) + + file_map = {file: set() for file in shard_filenames} + for weight, file in index["weight_map"].items(): + file_map[file].add(weight) + + sharded_metadata["file_map"] = file_map + + # First, let's deal with local folder. + # TODO: if optimizer_path is a folder, we should check if the optimizer is already cached or not. 
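+    # For a local checkpoint folder, resolve each shard file name against that folder.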
+ if os.path.isdir(optimizer_path): + shard_filenames = [os.path.join(optimizer_path, f) for f in shard_filenames] + return shard_filenames, sharded_metadata + + +def generate_base_static_name(vname): + """ + Return base static name and specific type name, like [embedding_0.w_0, moment1_0] + """ + if FP32_MASTER in vname: + vname = vname.split("_" + FP32_MASTER + "_") + return vname[0], vname[1] + else: + # Directly deal with type names, for example: moe_gate_1_moment1_0. + type_names = optimizer_scalar_name + optimizer_non_scaler_name + for name in type_names: + if name in vname: + a = vname.split(name)[0][:-1] + b = name + return a, b + + +def merge_large_tensor_parallel(tensor, tp_group, tp_action, dst_rank, is_dst): + """ + Move large tensor merge process to CPU, in order to avoid OOM. + """ + num_rows = tensor.shape[0] + num_splits = 4 + parts = np.array_split(np.arange(num_rows), num_splits) + splits = [len(part) for part in parts] + split_parts = np.insert(np.cumsum(splits), 0, 0) + split_tensors = [] + for i in range(num_splits): + if get_env_device() == "xpu": + ret = distributed_allgather(tensor[split_parts[i] : split_parts[i + 1], :], group=tp_group, offload=False) + else: + ret = distributed_gather( + tensor[split_parts[i] : split_parts[i + 1], :], dst=dst_rank, group=tp_group, offload=False + ) + # Copy to CPUPlace temporarily, may lower speed. + if ret is not None: + ret = [t.cpu() for t in ret] + split_tensors.append(ret) + concat_tensors = [] + if is_dst: + for i in range(tp_group.nranks): + tmp = [] + for j in range(num_splits): + tmp.append(split_tensors[j][i]) + concat_tensors.append(paddle.concat(tmp)) + tensor = tp_action(concat_tensors) + else: + tensor = None + return tensor + + +def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys): + """ + Merge tensor parallel according to tp_actions, used for model weight. + """ + hcg = fleet.get_hybrid_communicate_group() + tp_group = hcg.get_model_parallel_group() + dp_group = hcg.get_data_parallel_group() + tp_rank = tp_group.rank + dp_rank = dp_group.rank if dp_group.nranks > 1 else 0 + + # filter actions for pipeline mode + if hcg.get_pipe_parallel_group().nranks > 1: + filter_keys = set([y for x in all_filter_keys for y in x]) + for key in list(tp_actions.keys()): + if key not in filter_keys: + tp_actions.pop(key) + + state_dict_to_save = {} + max_key_len = max([len(_) for _ in all_filter_keys]) + for i in range(max_key_len): + for j, filter_keys in enumerate(all_filter_keys): + is_dst = tp_rank == j + if i > len(filter_keys) - 1: + continue + key = filter_keys[i] + tensor = state_dict[key] + # When using expert parallel, there's no need to save tensors with `no_sync=False` when dp_rank > 0. 
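+            # Tensors without the `no_sync` mark are replicated across data-parallel ranks, so only dp_rank 0 keeps them.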
+ if dp_rank > 0 and not getattr(tensor, "no_sync", False): + continue + if key in tp_actions: + # Get tensor size + tensor_bytes = tensor.numel().item() * dtype_byte_size(tensor.dtype) * tp_group.nranks + if tensor_bytes >= 5 * 1024 * 1024 * 1024: # temporarily set 5GB as threshold + tensor = merge_large_tensor_parallel(tensor, tp_group, tp_actions[key], j, is_dst) + else: + if get_env_device() == "xpu": + ret = distributed_allgather(tensor, group=tp_group, offload=False) + else: + ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False) + action = tp_actions.pop(key) + tensor = action(ret) if is_dst else None + else: + if is_dst: + tensor = tensor._copy_to(DEST_PLACE, False) if tensor.place.is_cpu_place() else tensor + else: + tensor = None + + if is_dst: + state_dict_to_save[key] = tensor + + if len(tp_actions) > 0: + for x in tp_actions.keys(): + logger.debug(f"key <{x}> need to merge tensor parallel but we can't find in model state.") + + return state_dict_to_save + + +def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys, model_state_dict=None): + """ + Merge tensor parallel according to tp_actions, used for master_weight and optimizer weight. + """ + hcg = fleet.get_hybrid_communicate_group() + tp_group = hcg.get_model_parallel_group() + dp_group = hcg.get_data_parallel_group() + tp_rank = tp_group.rank + dp_rank = dp_group.rank if dp_group.nranks > 1 else 0 + + no_sync_kname = [] + if model_state_dict is not None: + for k, v in model_state_dict.items(): + if getattr(v, "no_sync", False): + no_sync_kname.append(k) + + state_dict_to_save = {} + max_key_len = max([len(_) for _ in all_filter_keys]) + for i in range(max_key_len): + for j, filter_keys in enumerate(all_filter_keys): + is_dst = tp_rank == j + if i > len(filter_keys) - 1: + continue + # get base model key + model_key = filter_keys[i].split("/")[0] + tensor = state_dict[filter_keys[i]] + # When using expert parallel, there's no need to save tensors with `no_sync=False` when dp_rank > 0. + if dp_rank > 0 and model_key not in no_sync_kname: + continue + if model_key in tp_actions: + # for example: beta1, beta2 + if tensor.numel().item() == 1: + if is_dst: + tensor = tensor._copy_to(DEST_PLACE, False) if not tensor.place.is_cpu_place() else tensor + else: + tensor = None + else: + # Get tensor size + tensor_bytes = tensor.numel().item() * dtype_byte_size(tensor.dtype) * tp_group.nranks + if tensor_bytes >= 5 * 1024 * 1024 * 1024: # temporarily set 5GB as threshold + tensor = merge_large_tensor_parallel(tensor, tp_group, tp_actions[model_key], j, is_dst) + else: + if get_env_device() == "xpu": + ret = distributed_allgather(tensor, group=tp_group, offload=False) + else: + ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False) + action = tp_actions[model_key] + tensor = action(ret) if is_dst else None + else: + if is_dst: + tensor = tensor._copy_to(DEST_PLACE, False) if not tensor.place.is_cpu_place() else tensor + else: + tensor = None + + if is_dst: + state_dict_to_save[filter_keys[i]] = tensor + + return state_dict_to_save + + +def filter_params(model_to_save, state_dict, is_optimizer=False): + """ + Group according to the size of the tensor, aiming to make the weight size + stored on each device as equal as possible. 
+ """ + hcg = fleet.get_hybrid_communicate_group() + tp_group = hcg.get_model_parallel_group() + + tp_size = tp_group.nranks + tp_rank = tp_group.rank + + # for pure sharding or pure pp + if tp_size <= 1: + return [list(state_dict.keys())] + + filter_tensor_list = [[] for _ in range(tp_size)] + + if tp_rank == 0: + tensor_bytes_dict = {} + model_state_dict = get_expected_state_dict(model_to_save) + for (k, v) in state_dict.items(): + model_v = model_state_dict[k.split("/")[0]] if is_optimizer else v + if hasattr(model_v, "is_distributed") and model_v.is_distributed: + tensor_bytes_dict[k] = v.numel().item() * tp_size * dtype_byte_size(v.dtype) + else: + tensor_bytes_dict[k] = v.numel().item() * dtype_byte_size(v.dtype) + + filter_tensor_list = [] + current_block = [] + current_block_size = 0 + total_size = 0 + + max_shard_size = (sum(tensor_bytes_dict.values()) + tp_size - 1) // tp_size + + for index, (key, weight_size) in enumerate(tensor_bytes_dict.items()): + # If this weight is going to tip up over the maximal size, we split. + # if current_block_size + weight_size > max_shard_size: + if total_size + weight_size > max_shard_size * (len(filter_tensor_list) + 1) or ( + len(tensor_bytes_dict) - index < (tp_size - len(filter_tensor_list)) + ): + # fix if the first param is large than max_shard_size + if len(current_block) > 0: + filter_tensor_list.append(current_block) + current_block = [] + current_block_size = 0 + + current_block.append(key) + current_block_size += weight_size + total_size += weight_size + + filter_tensor_list.append(current_block) + if len(filter_tensor_list) < tp_size: + filter_tensor_list.extend([[] for i in range(tp_size - len(filter_tensor_list))]) + + dist.broadcast_object_list( + filter_tensor_list, + src=hcg.get_model_parallel_group_src_rank(), + group=tp_group, + ) + + return filter_tensor_list + + +def get_sharded_file_name(args, file_name, is_optimizer=False): + """ + Get safetensors file name for saving. + """ + if not is_optimizer: + sd_degree = args.sharding_parallel_degree if args.sharding_parallel_degree > 1 else 1 + size = sd_degree if args.use_expert_parallel else args.dataset_world_size + shard_file = file_name.replace( + ".pdparams", + f"-{args.logical_process_index + 1:05d}-of-{args.world_size//size:05d}.pdparams", + ) + shard_file = shard_file.replace( + ".safetensors", + f"-{args.logical_process_index + 1:05d}-of-{args.world_size//size:05d}.safetensors", + ) + else: + hcg = fleet.get_hybrid_communicate_group() + dp_group = hcg.get_data_parallel_group() + size = dp_group.nranks if not args.use_expert_parallel else 1 + shard_file = file_name.replace( + ".pdparams", f"-{args.logical_process_index + 1:05d}-of-{args.world_size//size:05d}.pdparams" + ) + shard_file = shard_file.replace( + ".safetensors", + f"-{args.logical_process_index + 1:05d}-of-{args.world_size//size:05d}.safetensors", + ) + shard_file = shard_file.replace( + ".pdopt", f"-{args.logical_process_index + 1:05d}-of-{args.world_size//size:05d}.pdopt" + ) + return shard_file + + +def get_sharded_index( + index_file_list, + total_size_list, +): + """ + Save safetensors index json file, including metadata and weight_map. 
+ """ + local_rank = int(os.getenv("PADDLE_RANK_IN_NODE", 0)) + if local_rank == 0: + sharded_index_json = {} + + sharded_index_json["metadata"] = {"total_size": sum(total_size_list)} + + weight_map = {} + for i, _ in enumerate(index_file_list): + weight_map.update(index_file_list[i]) + + sharded_index_json["weight_map"] = weight_map + return sharded_index_json + + return None + + +def gather_sharded_object(index_file, total_size, is_optimizer=False, use_expert_parallel=False): + """ + All gather sharded files list across different groups. + """ + index_file_list, total_size_list = [], [] + + hcg = fleet.get_hybrid_communicate_group() + tp_group = hcg.get_model_parallel_group() + pp_group = hcg.get_pipe_parallel_group() + + logger.info( + f"Unified checkpoint: generating sharded_index json files for {'optimizer or master weight' if is_optimizer else 'model weight'}." + ) + + if tp_group.nranks > 1: + dist.all_gather_object(index_file_list, index_file, tp_group) + dist.all_gather_object(total_size_list, total_size, tp_group) + if pp_group.nranks > 1: + pp_index_file_list = [] + pp_total_size_list = [] + dist.all_gather_object( + pp_index_file_list, index_file_list if len(index_file_list) > 0 else index_file, pp_group + ) + dist.all_gather_object( + pp_total_size_list, total_size_list if len(total_size_list) > 0 else total_size, pp_group + ) + index_file_list = pp_index_file_list + total_size_list = pp_total_size_list + + index_file_list = flatten_list(index_file_list) + total_size_list = flatten_list(total_size_list) + + # for pure sharding + if len(index_file_list) == 0 and len(total_size_list) == 0: + index_file_list = [index_file] + total_size_list = [total_size] + + if use_expert_parallel: + data_group = hcg.get_data_parallel_group() + if data_group.nranks > 1: + data_index_file_list = [] + data_total_size_list = [] + dist.all_gather_object(data_index_file_list, index_file_list, data_group) + dist.all_gather_object(data_total_size_list, total_size_list, data_group) + index_file_list = flatten_list(data_index_file_list) + total_size_list = flatten_list(data_total_size_list) + + if is_optimizer: + sharding_group = hcg.get_sharding_parallel_group() + if sharding_group.nranks > 1: + sharding_index_file_list = [] + sharding_total_size_list = [] + dist.all_gather_object(sharding_index_file_list, index_file_list, sharding_group) + dist.all_gather_object(sharding_total_size_list, total_size_list, sharding_group) + index_file_list = flatten_list(sharding_index_file_list) + total_size_list = flatten_list(sharding_total_size_list) + + return index_file_list, total_size_list + + +def rename_shard_file(args, shard_file, file_name): + """ + Rename shard file when using expert_parallel. 
+ """ + assert args.use_expert_parallel, "only expert_parallel need to use this function" + + shard_file_list = [] + + hcg = fleet.get_hybrid_communicate_group() + tp_group = hcg.get_model_parallel_group() + pp_group = hcg.get_pipe_parallel_group() + data_group = hcg.get_data_parallel_group() + + if tp_group.nranks > 1: + dist.all_gather_object(shard_file_list, shard_file, tp_group) + if pp_group.nranks > 1: + pp_shard_file_list = [] + dist.all_gather_object( + pp_shard_file_list, shard_file_list if len(shard_file_list) > 0 else shard_file, pp_group + ) + shard_file_list = flatten_list(pp_shard_file_list) + if data_group.nranks > 1: + data_shard_file_list = [] + dist.all_gather_object( + data_shard_file_list, shard_file_list if len(shard_file_list) > 0 else shard_file, data_group + ) + shard_file_list = flatten_list(data_shard_file_list) + + new_index = shard_file_list.index(shard_file) + sd_degree = args.sharding_parallel_degree if args.sharding_parallel_degree > 1 else 1 + shard_file = file_name.replace( + ".pdparams", + f"-{new_index + 1:05d}-of-{args.world_size//sd_degree:05d}.pdparams", + ) + shard_file = shard_file.replace( + ".safetensors", + f"-{new_index + 1:05d}-of-{args.world_size//sd_degree:05d}.safetensors", + ) + return shard_file + + +def save_prefix_past_key_value(model_to_save, save_directory): + """ + Used only for PrefixModelForCausalLM. + """ + past_key_value = model_to_save.prefix_encoder(model_to_save.prefix_tokens.unsqueeze(0).expand([1, -1])) + past_key_value = past_key_value.reshape( + [ + model_to_save.prefix_config.num_prefix_tokens, + 2, + model_to_save.prefix_config.num_hidden_layers, + model_to_save.num_heads, + model_to_save.head_dim, + ] + ) + past_key_value = paddle.transpose(past_key_value, perm=[2, 1, 3, 0, 4]).cpu().numpy() + np.save(os.path.join(save_directory, PAST_KEY_VALUES_FILE_NAME), past_key_value) + + +def is_sharding_split_param_mode(args): + return ( + args.sharding_parallel_degree > 1 + and ShardingOption.SHARD_OP in args.sharding + and "split_param" in args.sharding_parallel_config + ) + + +def save_model_config(model_to_save, save_directory): + """ + Save model config. + """ + + def save_config(model_to_save): + dtype = get_parameter_dtype(model_to_save) + model_to_save.config.dtype = str(dtype).split(".")[1] + config_to_save = copy.deepcopy(model_to_save.config) + + if config_to_save.tensor_parallel_degree > 1: + # do we need to change? 
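+            # The unified weights are merged across tp ranks before saving, so the stored config describes a tp-degree-1 layout.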
+ config_to_save.tensor_parallel_degree = 1 + + return config_to_save + + # Save prefix model past_key_values + if isinstance(model_to_save, PrefixModelForCausalLM): + save_prefix_past_key_value(model_to_save, save_directory) + model_to_save.prefix_config.save_pretrained(save_directory) + if isinstance(model_to_save, LoRAModel): + model_to_save.lora_config.save_pretrained(save_directory) + + # save the config + config_to_save = save_config(model_to_save) + # Attach architecture to the config + if isinstance(model_to_save, LoRAModel) or isinstance(model_to_save, PrefixModelForCausalLM): + config_to_save.architectures = [model_to_save.model.__class__.__name__] + else: + config_to_save.architectures = [model_to_save.__class__.__name__] + + config_to_save.save_pretrained(save_directory) + # save generation config + if model_to_save.can_generate(): + model_to_save.generation_config.save_pretrained(save_directory) diff --git a/paddlenlp/utils/nested.py b/paddlenlp/utils/nested.py index 4e800231843c..43f012aa3d0e 100644 --- a/paddlenlp/utils/nested.py +++ b/paddlenlp/utils/nested.py @@ -116,3 +116,13 @@ def nested_copy_place(inputs, place=None, blocking=False): if isinstance(inputs, paddle.Tensor): inputs = inputs if inputs.place == place else inputs._copy_to(place, blocking) return inputs + + +def flatten_list(nested_list): + flattened_list = [] + for item in nested_list: + if isinstance(item, list): + flattened_list.extend(flatten_list(item)) + else: + flattened_list.append(item) + return flattened_list diff --git a/tests/trainer/test_unified_checkpoint.py b/tests/trainer/test_unified_checkpoint.py index 17fe0f14f9ea..8f5a1dfe7236 100644 --- a/tests/trainer/test_unified_checkpoint.py +++ b/tests/trainer/test_unified_checkpoint.py @@ -18,7 +18,7 @@ import numpy as np import pytest -from paddlenlp.trainer.plugins.unified_checkpoint import UnifiedCheckpointOption +from paddlenlp.trainer.unified_checkpoint.utils import UnifiedCheckpointOption from tests.parallel_launch import TestMultipleGpus from tests.testing_utils import ( require_paddle_at_least_2_gpu,