From 1abdd2cb6d8c2f4d1d980a47b0341b29adc49acb Mon Sep 17 00:00:00 2001
From: DrownFish19
Date: Mon, 8 Jan 2024 03:45:26 +0000
Subject: [PATCH 1/3] fix(unified checkpoint): add config save

---
 .../trainer/plugins/unified_checkpoint.py | 65 +++++++++++--------
 1 file changed, 37 insertions(+), 28 deletions(-)

diff --git a/paddlenlp/trainer/plugins/unified_checkpoint.py b/paddlenlp/trainer/plugins/unified_checkpoint.py
index 0d8a93dd9151..73d894e1143b 100644
--- a/paddlenlp/trainer/plugins/unified_checkpoint.py
+++ b/paddlenlp/trainer/plugins/unified_checkpoint.py
@@ -115,6 +115,7 @@ def save_unified_checkpoint(args, model, optimizer, output_dir, safe_serializati
     else:
         raise ValueError("Unified checkpoint only supports PretrainedModel")
 
+    skip_save_model_weight = False
     if UnifiedCheckpointOption.SKIP_SAVE_MODEL_WEIGHT.value in args.unified_checkpoint_config:
         if is_need_master_weight(optimizer, is_fp16_or_bp16=(args.fp16 or args.bf16)):
             logger.info(
@@ -122,37 +123,39 @@ def save_unified_checkpoint(args, model, optimizer, output_dir, safe_serializati
                 "The master weight will be loaded as model weights for next resumption."
             )
             # not save model weight, load from master weight
-            return
-
-    config_to_save = None
-    state_dict, config_to_save, shard_file, sharded_index = unified_checkpoint_into_shards(
-        args, model_to_save, safe_serialization=safe_serialization
-    )
+            skip_save_model_weight = True
 
     save_directory = output_dir
     os.makedirs(save_directory, exist_ok=True)
 
-    is_sync_save = True
-    if "async_save" in args.unified_checkpoint_config:
-        is_sync_save = False
-    file_save_async_or_sync(
-        state_dict, os.path.join(save_directory, shard_file), safe_serialization, is_sync=is_sync_save
-    )
+    # save model weights
+    if skip_save_model_weight:
+        state_dict, shard_file, sharded_index = unified_checkpoint_into_shards(
+            args, model_to_save, safe_serialization=safe_serialization
+        )
+        is_sync_save = True
+        if "async_save" in args.unified_checkpoint_config:
+            is_sync_save = False
+        file_save_async_or_sync(
+            state_dict, os.path.join(save_directory, shard_file), safe_serialization, is_sync=is_sync_save
+        )
+        if sharded_index is not None:
+            if not safe_serialization:
+                path = os.path.join(output_dir, PADDLE_WEIGHTS_INDEX_NAME)
+            else:
+                path = os.path.join(output_dir, SAFE_WEIGHTS_INDEX_NAME)
+
+            with open(path, "w") as f:
+                json.dump(sharded_index, f, indent=4)
+
+    # save the config
+    config_to_save = save_config(model_to_save)
     # Attach architecture to the config
     config_to_save.architectures = [model_to_save.__class__.__name__]
-    # Save the config
     if args.should_save:
         config_to_save.save_pretrained(save_directory)
 
-    if sharded_index is not None:
-        if not safe_serialization:
-            path = os.path.join(output_dir, PADDLE_WEIGHTS_INDEX_NAME)
-        else:
-            path = os.path.join(output_dir, SAFE_WEIGHTS_INDEX_NAME)
-
-        with open(path, "w") as f:
-            json.dump(sharded_index, f, indent=4)
-
 
 def load_unified_checkpoint(args, model, optimizer, resume_from_checkpoint: str, safe_serialization=False) -> None:
     """Load potential model checkpoint
@@ -252,6 +255,18 @@ def _remove_unused_keys(
         raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
 
 
+def save_config(model_to_save):
+    dtype = get_parameter_dtype(model_to_save)
+    model_to_save.config.dtype = str(dtype).split(".")[1]
+    config_to_save = copy.deepcopy(model_to_save.config)
+
+    if config_to_save.tensor_parallel_degree > 1:
+        # do we need to change?
+        config_to_save.tensor_parallel_degree = 1
+
+    return config_to_save
+
+
 def unified_checkpoint_into_shards(
     args,
     model_to_save,
@@ -272,8 +287,6 @@ def unified_checkpoint_into_shards(
 
     all_filter_keys = filter_params(model_to_save, state_dict)
 
-    dtype = get_parameter_dtype(model_to_save)
-    model_to_save.config.dtype = str(dtype).split(".")[1]
     config_to_save = copy.deepcopy(model_to_save.config)
 
     if config_to_save.tensor_parallel_degree > 1:
@@ -282,10 +295,6 @@
         )
         state_dict = merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys)
 
-    if config_to_save.tensor_parallel_degree > 1:
-        # do we need to change?
-        config_to_save.tensor_parallel_degree = 1
-
     # build index json file
     index_weight_file = {}
     total_size = 0
@@ -302,7 +311,7 @@
         total_size_list,
     )
 
-    return state_dict, config_to_save, shard_file, sharded_index
+    return state_dict, shard_file, sharded_index
 
 
 def save_unified_optimizer(args, model, optimizer, output_dir, safe_serialization=False):
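
Note: the sketch below is illustrative and not part of the diff. It condenses the save flow that PATCH 1/3 sets up: the config is always written through the new save_config helper, while the model weights are written only when they are not skipped. The guard reads `if skip_save_model_weight:` in this patch and is inverted to `if not skip_save_model_weight:` by PATCH 3/3; the sketch uses the corrected form. All names come from unified_checkpoint.py itself.

def save_flow_sketch(args, model_to_save, optimizer, output_dir):
    # Illustrative only: simplified from the patch above, not the exact library code.
    skip_save_model_weight = (
        UnifiedCheckpointOption.SKIP_SAVE_MODEL_WEIGHT.value in args.unified_checkpoint_config
        and is_need_master_weight(optimizer, is_fp16_or_bp16=(args.fp16 or args.bf16))
    )
    os.makedirs(output_dir, exist_ok=True)

    if not skip_save_model_weight:  # condition as corrected by PATCH 3/3
        ...  # shard the weights, write them (sync or async), and dump the weights index json

    # the config is saved in every case, via the new helper
    config_to_save = save_config(model_to_save)
    config_to_save.architectures = [model_to_save.__class__.__name__]
    if args.should_save:
        config_to_save.save_pretrained(output_dir)
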
From 96253fc8028ca1771eb4d88b66a68f407893b6b9 Mon Sep 17 00:00:00 2001
From: DrownFish19
Date: Mon, 8 Jan 2024 04:04:20 +0000
Subject: [PATCH 2/3] fix(unified checkpoint): name change

change master weights name to model weights name when SKIP_SAVE_MODEL_WEIGHT
is enabled
---
 paddlenlp/trainer/plugins/unified_checkpoint.py | 17 ++++++++++-------
 1 file changed, 10 insertions(+), 7 deletions(-)

diff --git a/paddlenlp/trainer/plugins/unified_checkpoint.py b/paddlenlp/trainer/plugins/unified_checkpoint.py
index 73d894e1143b..af7bad4de7cd 100644
--- a/paddlenlp/trainer/plugins/unified_checkpoint.py
+++ b/paddlenlp/trainer/plugins/unified_checkpoint.py
@@ -352,16 +352,17 @@ def save_unified_optimizer(args, model, optimizer, output_dir, safe_serializatio
     )
 
     if sharded_optim_index is not None:
-        if not safe_serialization:
-            path = os.path.join(output_dir, PADDLE_OPTIMIZER_INDEX_NAME)
-            master_path = os.path.join(output_dir, PADDLE_MASTER_WEIGHTS_INDEX_NAME)
-        else:
-            path = os.path.join(output_dir, SAFE_OPTIMIZER_INDEX_NAME)
-            master_path = os.path.join(output_dir, SAFE_MASTER_WEIGHTS_INDEX_NAME)
-
+        optimizer_index_name = SAFE_OPTIMIZER_INDEX_NAME if safe_serialization else PADDLE_OPTIMIZER_INDEX_NAME
+        path = os.path.join(output_dir, optimizer_index_name)
         with open(path, "w") as f:
             json.dump(sharded_optim_index, f, indent=4)
 
+        master_weights_name = (
+            SAFE_MASTER_WEIGHTS_INDEX_NAME if safe_serialization else PADDLE_MASTER_WEIGHTS_INDEX_NAME
+        )
+        if UnifiedCheckpointOption.SKIP_SAVE_MODEL_WEIGHT.value in args.unified_checkpoint_config:
+            master_weights_name = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else PADDLE_WEIGHTS_INDEX_NAME
+        master_path = os.path.join(output_dir, master_weights_name)
         if master_weight_state_dict is not None:
             with open(master_path, "w") as f:
                 json.dump(sharded_master_weight_index, f, indent=4)
@@ -570,6 +571,8 @@ def unified_optimizer_into_shards(
     total_optim_size, total_master_weight_size = 0, 0
     optimizer_name = SAFE_OPTIMIZER_NAME if safe_serialization else PADDLE_OPTIMIZER_NAME
     master_weights_name = SAFE_MASTER_WEIGHTS_NAME if safe_serialization else PADDLE_MASTER_WEIGHTS_NAME
+    if UnifiedCheckpointOption.SKIP_SAVE_MODEL_WEIGHT.value in args.unified_checkpoint_config:
+        master_weights_name = SAFE_WEIGHTS_NAME if safe_serialization else PADDLE_WEIGHTS_NAME
     shard_optimizer_file = get_sharded_file_name(args, optimizer_name, is_optimizer=True)
     shard_master_weight_file = get_sharded_file_name(args, master_weights_name, is_optimizer=True)
 
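
Note: illustrative condensation, not part of the diff. PATCH 2/3 only changes which file names the master weights are written under: with SKIP_SAVE_MODEL_WEIGHT they reuse the model-weights names, so the next resumption can read them back as model weights. The `master_index_name` variable below exists only for this sketch; the patch itself reuses `master_weights_name` for the index as well.

# Sketch of the naming rule added above, using the module's own constants.
skip = UnifiedCheckpointOption.SKIP_SAVE_MODEL_WEIGHT.value in args.unified_checkpoint_config
if skip:
    # master weights are saved under the model-weights names
    master_weights_name = SAFE_WEIGHTS_NAME if safe_serialization else PADDLE_WEIGHTS_NAME
    master_index_name = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else PADDLE_WEIGHTS_INDEX_NAME
else:
    master_weights_name = SAFE_MASTER_WEIGHTS_NAME if safe_serialization else PADDLE_MASTER_WEIGHTS_NAME
    master_index_name = SAFE_MASTER_WEIGHTS_INDEX_NAME if safe_serialization else PADDLE_MASTER_WEIGHTS_INDEX_NAME
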
From c78b1d96ca4c3250684a9ab3a6ff192291676c66 Mon Sep 17 00:00:00 2001
From: DrownFish19
Date: Mon, 8 Jan 2024 05:24:47 +0000
Subject: [PATCH 3/3] fix(unified checkpoint): model weights load when skipping
 model weights save

When skipping the model weights save and saving the master weights under the
model weights names, unified checkpoint needs to choose the model weights
files to load into the master weights.
---
 paddlenlp/trainer/plugins/unified_checkpoint.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/paddlenlp/trainer/plugins/unified_checkpoint.py b/paddlenlp/trainer/plugins/unified_checkpoint.py
index af7bad4de7cd..1ed33ef21f88 100644
--- a/paddlenlp/trainer/plugins/unified_checkpoint.py
+++ b/paddlenlp/trainer/plugins/unified_checkpoint.py
@@ -129,7 +129,7 @@ def save_unified_checkpoint(args, model, optimizer, output_dir, safe_serializati
     os.makedirs(save_directory, exist_ok=True)
 
     # save model weights
-    if skip_save_model_weight:
+    if not skip_save_model_weight:
         state_dict, shard_file, sharded_index = unified_checkpoint_into_shards(
             args, model_to_save, safe_serialization=safe_serialization
         )
@@ -1660,6 +1660,10 @@ def update_master_weight_status(args, optimizer, has_master_weight, safe_seriali
             index_filename_master_weights = (
                 PADDLE_MASTER_WEIGHTS_INDEX_NAME if not safe_serialization else SAFE_MASTER_WEIGHTS_INDEX_NAME
             )
+            if UnifiedCheckpointOption.SKIP_SAVE_MODEL_WEIGHT.value in args.unified_checkpoint_config:
+                index_filename_master_weights = (
+                    PADDLE_WEIGHTS_INDEX_NAME if not safe_serialization else SAFE_WEIGHTS_INDEX_NAME
+                )
     else:
         has_master_weight = False
        index_filename_master_weights = None
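
Note: illustrative summary, not part of the diff. Taken together, the three patches make SKIP_SAVE_MODEL_WEIGHT a round trip: on save, the model weights are skipped and the master weights are written under the model-weights names (PATCH 1/3 and PATCH 2/3); on resumption, update_master_weight_status picks the model-weights index as the source of the master weights (PATCH 3/3). Reduced to the branch added above:

# Sketch only: the load-side selection, using the module's own constants.
if UnifiedCheckpointOption.SKIP_SAVE_MODEL_WEIGHT.value in args.unified_checkpoint_config:
    # the master weights were written under the model-weights index name
    index_filename_master_weights = (
        PADDLE_WEIGHTS_INDEX_NAME if not safe_serialization else SAFE_WEIGHTS_INDEX_NAME
    )
else:
    index_filename_master_weights = (
        PADDLE_MASTER_WEIGHTS_INDEX_NAME if not safe_serialization else SAFE_MASTER_WEIGHTS_INDEX_NAME
    )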