[Unified Checkpoint] fix checkpoint names (#7794)
When skipping the model weight save and saving the master weights as the model weights, the unified checkpoint needs to choose the model weight files when loading the master weights (see the sketch below).
DrownFish19 committed Jan 8, 2024
1 parent bb9062e commit 672ee98
Showing 1 changed file with 51 additions and 35 deletions.
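
For orientation before reading the diff: a minimal, self-contained sketch of the filename-selection rule this change introduces. The helper name select_master_weight_index_name and the literal filename values are illustrative assumptions, not the actual PaddleNLP code; only the constant names mirror those used in unified_checkpoint.py.

# Sketch only: the helper and the filename literals are assumptions for
# illustration; the constant names follow unified_checkpoint.py.
SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
PADDLE_WEIGHTS_INDEX_NAME = "model_state.pdparams.index.json"
SAFE_MASTER_WEIGHTS_INDEX_NAME = "master_weights.safetensors.index.json"
PADDLE_MASTER_WEIGHTS_INDEX_NAME = "master_weights.pdparams.index.json"
SKIP_SAVE_MODEL_WEIGHT = "skip_save_model_weight"

def select_master_weight_index_name(unified_checkpoint_config, safe_serialization):
    # When the model-weight save is skipped, the master weights are written
    # under the model-weight filenames, so they must also be loaded from the
    # model-weight index instead of the dedicated master-weight index.
    if SKIP_SAVE_MODEL_WEIGHT in unified_checkpoint_config:
        return SAFE_WEIGHTS_INDEX_NAME if safe_serialization else PADDLE_WEIGHTS_INDEX_NAME
    return SAFE_MASTER_WEIGHTS_INDEX_NAME if safe_serialization else PADDLE_MASTER_WEIGHTS_INDEX_NAME

# With skip_save_model_weight enabled and safetensors serialization, the
# master weights resolve to the model-weight index file.
assert select_master_weight_index_name(["skip_save_model_weight"], True) == "model.safetensors.index.json"

This mirrors the branches the diff adds in save_unified_optimizer, unified_optimizer_into_shards, and update_master_weight_status.
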
86 changes: 51 additions & 35 deletions paddlenlp/trainer/plugins/unified_checkpoint.py
@@ -115,44 +115,47 @@ def save_unified_checkpoint(args, model, optimizer, output_dir, safe_serializati
else:
raise ValueError("Unified checkpoint only supports PretrainedModel")

skip_save_model_weight = False
if UnifiedCheckpointOption.SKIP_SAVE_MODEL_WEIGHT.value in args.unified_checkpoint_config:
if is_need_master_weight(optimizer, is_fp16_or_bp16=(args.fp16 or args.bf16)):
logger.info(
f"With {UnifiedCheckpointOption.SKIP_SAVE_MODEL_WEIGHT.value}, skip the model checkpoint save."
"The master weight will be loaded as model weights for next resumption."
)
# not save model weight, load from master weight
return
config_to_save = None
state_dict, config_to_save, shard_file, sharded_index = unified_checkpoint_into_shards(
args, model_to_save, safe_serialization=safe_serialization
)
skip_save_model_weight = True

save_directory = output_dir
os.makedirs(save_directory, exist_ok=True)

is_sync_save = True
if "async_save" in args.unified_checkpoint_config:
is_sync_save = False
file_save_async_or_sync(
state_dict, os.path.join(save_directory, shard_file), safe_serialization, is_sync=is_sync_save
)
# save model weights
if not skip_save_model_weight:
state_dict, shard_file, sharded_index = unified_checkpoint_into_shards(
args, model_to_save, safe_serialization=safe_serialization
)
is_sync_save = True
if "async_save" in args.unified_checkpoint_config:
is_sync_save = False
file_save_async_or_sync(
state_dict, os.path.join(save_directory, shard_file), safe_serialization, is_sync=is_sync_save
)

if sharded_index is not None:
if not safe_serialization:
path = os.path.join(output_dir, PADDLE_WEIGHTS_INDEX_NAME)
else:
path = os.path.join(output_dir, SAFE_WEIGHTS_INDEX_NAME)

with open(path, "w") as f:
json.dump(sharded_index, f, indent=4)

# save the config
config_to_save = save_config(model_to_save)
# Attach architecture to the config
config_to_save.architectures = [model_to_save.__class__.__name__]
# Save the config
if args.should_save:
config_to_save.save_pretrained(save_directory)

if sharded_index is not None:
if not safe_serialization:
path = os.path.join(output_dir, PADDLE_WEIGHTS_INDEX_NAME)
else:
path = os.path.join(output_dir, SAFE_WEIGHTS_INDEX_NAME)

with open(path, "w") as f:
json.dump(sharded_index, f, indent=4)


def load_unified_checkpoint(args, model, optimizer, resume_from_checkpoint: str, safe_serialization=False) -> None:
"""Load potential model checkpoint
@@ -252,6 +255,18 @@ def _remove_unused_keys(
raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")


def save_config(model_to_save):
dtype = get_parameter_dtype(model_to_save)
model_to_save.config.dtype = str(dtype).split(".")[1]
config_to_save = copy.deepcopy(model_to_save.config)

if config_to_save.tensor_parallel_degree > 1:
# do we need to change?
config_to_save.tensor_parallel_degree = 1

return config_to_save


def unified_checkpoint_into_shards(
args,
model_to_save,
@@ -272,8 +287,6 @@ def unified_checkpoint_into_shards(

all_filter_keys = filter_params(model_to_save, state_dict)

dtype = get_parameter_dtype(model_to_save)
model_to_save.config.dtype = str(dtype).split(".")[1]
config_to_save = copy.deepcopy(model_to_save.config)

if config_to_save.tensor_parallel_degree > 1:
@@ -282,10 +295,6 @@ def unified_checkpoint_into_shards(
)
state_dict = merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys)

if config_to_save.tensor_parallel_degree > 1:
# do we need to change?
config_to_save.tensor_parallel_degree = 1

# build index json file
index_weight_file = {}
total_size = 0
@@ -302,7 +311,7 @@ def unified_checkpoint_into_shards(
total_size_list,
)

return state_dict, config_to_save, shard_file, sharded_index
return state_dict, shard_file, sharded_index


def save_unified_optimizer(args, model, optimizer, output_dir, safe_serialization=False):
@@ -343,16 +352,17 @@ def save_unified_optimizer(args, model, optimizer, output_dir, safe_serializatio
)

if sharded_optim_index is not None:
if not safe_serialization:
path = os.path.join(output_dir, PADDLE_OPTIMIZER_INDEX_NAME)
master_path = os.path.join(output_dir, PADDLE_MASTER_WEIGHTS_INDEX_NAME)
else:
path = os.path.join(output_dir, SAFE_OPTIMIZER_INDEX_NAME)
master_path = os.path.join(output_dir, SAFE_MASTER_WEIGHTS_INDEX_NAME)

optimizer_index_name = SAFE_OPTIMIZER_INDEX_NAME if safe_serialization else PADDLE_OPTIMIZER_INDEX_NAME
path = os.path.join(output_dir, optimizer_index_name)
with open(path, "w") as f:
json.dump(sharded_optim_index, f, indent=4)

master_weights_name = (
SAFE_MASTER_WEIGHTS_INDEX_NAME if safe_serialization else PADDLE_MASTER_WEIGHTS_INDEX_NAME
)
if UnifiedCheckpointOption.SKIP_SAVE_MODEL_WEIGHT.value in args.unified_checkpoint_config:
master_weights_name = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else PADDLE_WEIGHTS_INDEX_NAME
master_path = os.path.join(output_dir, master_weights_name)
if master_weight_state_dict is not None:
with open(master_path, "w") as f:
json.dump(sharded_master_weight_index, f, indent=4)
@@ -561,6 +571,8 @@ def unified_optimizer_into_shards(
total_optim_size, total_master_weight_size = 0, 0
optimizer_name = SAFE_OPTIMIZER_NAME if safe_serialization else PADDLE_OPTIMIZER_NAME
master_weights_name = SAFE_MASTER_WEIGHTS_NAME if safe_serialization else PADDLE_MASTER_WEIGHTS_NAME
if UnifiedCheckpointOption.SKIP_SAVE_MODEL_WEIGHT.value in args.unified_checkpoint_config:
master_weights_name = SAFE_WEIGHTS_NAME if safe_serialization else PADDLE_WEIGHTS_NAME
shard_optimizer_file = get_sharded_file_name(args, optimizer_name, is_optimizer=True)
shard_master_weight_file = get_sharded_file_name(args, master_weights_name, is_optimizer=True)

@@ -1648,6 +1660,10 @@ def update_master_weight_status(args, optimizer, has_master_weight, safe_seriali
index_filename_master_weights = (
PADDLE_MASTER_WEIGHTS_INDEX_NAME if not safe_serialization else SAFE_MASTER_WEIGHTS_INDEX_NAME
)
if UnifiedCheckpointOption.SKIP_SAVE_MODEL_WEIGHT.value in args.unified_checkpoint_config:
index_filename_master_weights = (
PADDLE_WEIGHTS_INDEX_NAME if not safe_serialization else SAFE_WEIGHTS_INDEX_NAME
)
else:
has_master_weight = False
index_filename_master_weights = None
