[Unified Checkpoint] Checkpoint compression #9183

Merged: 44 commits, Nov 25, 2024
Changes shown are from 16 of the 44 commits.

Commits (all by wtmlon):
cd4e5e0  checkpoint compression init (Sep 23, 2024)
7684576  add ckpt quant argument (Sep 24, 2024)
afcecad  add ckpt quant ci (Oct 11, 2024)
d8f3351  fix ci (Oct 11, 2024)
434bd4c  fix lint (Oct 11, 2024)
a98fb8b  remove stage O2, change O3 --> O2 (Oct 11, 2024)
2e5c73b  support async save (Oct 11, 2024)
6b1f3bf  file adjustment (Oct 14, 2024)
c4a80e7  magic string remove (Oct 14, 2024)
ae305a9  ci fix (Oct 14, 2024)
fd6ad57  ci fix, code refinement (Oct 14, 2024)
f766d15  function extraction (Oct 15, 2024)
e74b68b  fix ci (Oct 15, 2024)
a7b053d  code refinement (Oct 15, 2024)
10b1064  fix ci (Oct 15, 2024)
ad1dc75  fix ci (Oct 15, 2024)
fb2c2e9  Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i… (Oct 16, 2024)
a1c35af  support non merge tp ckpt quantization (Oct 18, 2024)
f8530c0  fix ci (Oct 18, 2024)
4e21fb9  update (Oct 18, 2024)
a602fe5  fix bug (Oct 21, 2024)
55b8639  code refactor (Oct 25, 2024)
3a87734  Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i… (Oct 25, 2024)
a3073aa  fix lint (Oct 25, 2024)
8a8aca7  fix ci (Oct 25, 2024)
bab5235  Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i… (Oct 28, 2024)
c3c500d  del old uc.py (Oct 28, 2024)
a45c7f6  fix lint (Oct 28, 2024)
a4a3e23  add mgpu ci (Oct 28, 2024)
2330839  fix ci (Oct 28, 2024)
3fcd471  multi thread loading (Oct 28, 2024)
f57aab5  fix lint (Oct 28, 2024)
50ee148  Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i… (Oct 29, 2024)
75a1011  fix bug (Nov 5, 2024)
ffd0823  Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i… (Nov 5, 2024)
4947a8c  refactor code (Nov 7, 2024)
3eaebbb  Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i… (Nov 19, 2024)
a6b2236  add comment (Nov 19, 2024)
a5d0afa  fix lint (Nov 19, 2024)
fdd92a8  add comment (Nov 19, 2024)
b2b20be  add comment (Nov 19, 2024)
432e97c  fix bug (Nov 20, 2024)
5eb201c  fix bugs when ckpt no quant and no master weight (Nov 21, 2024)
b2bcf16  remove uni-test (Nov 22, 2024)
4 changes: 3 additions & 1 deletion paddlenlp/peft/lora/lora_model.py
@@ -262,7 +262,9 @@ def from_pretrained(cls, model, lora_path, **kwargs):
pre_tensor_parallel_split = True
tp_actions = lora_model._get_tensor_parallel_convert_actions(loaded_keys, is_split=True)
state_dict = load_state_dict(
-shard_file, tp_actions if pre_tensor_parallel_split else None, expected_keys
+shard_file,
+tp_actions if pre_tensor_parallel_split else None,
+expected_keys,
)
error_msgs += _load_state_dict_into_model(lora_model.model, state_dict, "")
del state_dict
4 changes: 3 additions & 1 deletion paddlenlp/peft/prefix/prefix_model.py
@@ -333,7 +333,9 @@ def from_pretrained(
pre_tensor_parallel_split = True
tp_actions = prefix_model._get_tensor_parallel_convert_actions(is_split=True)
state_dict = load_state_dict(
-shard_file, tp_actions if pre_tensor_parallel_split else None, expected_keys
+shard_file,
+tp_actions if pre_tensor_parallel_split else None,
+expected_keys,
)
error_msgs += _load_state_dict_into_model(prefix_model.prefix_encoder, state_dict, "")
del state_dict
109 changes: 93 additions & 16 deletions paddlenlp/trainer/plugins/unified_checkpoint.py
@@ -48,7 +48,11 @@
)
from paddlenlp.utils.distributed import distributed_allgather, distributed_gather
from paddlenlp.utils.env import (
+BETA1_KEYNAME,
+BETA2_KEYNAME,
LORA_WEIGHTS_NAME,
+MOMENT1_KEYNAME,
+MOMENT2_KEYNAME,
PADDLE_MASTER_WEIGHTS_INDEX_NAME,
PADDLE_MASTER_WEIGHTS_NAME,
PADDLE_OPTIMIZER_INDEX_NAME,
@@ -86,6 +90,7 @@
_traverse_copy_to_shm,
create_meta_dict,
)
+from .unified_checkpoint_quantization import quant_unified_optimizer

FP32_MASTER = "fp32_master_0"
optimizer_scalar_name = [
@@ -115,6 +120,7 @@

SKIP_SAVE_MODEL_WEIGHT = "skip_save_model_weight"
MASTER_WEIGHT_COMPATIBLE = "master_weight_compatible"
+REMOVE_MASTER_WEIGHT = "remove_master_weight"
ASYNC_SAVE = "async_save"
IGNORE_MERGE_OPTIMIZER = "ignore_merge_optimizer"

@@ -149,11 +155,15 @@
self._shared_save_master_weight_flag = multiprocessing.Array("i", 1)
self._shared_save_optimizer_flag = multiprocessing.Array("i", 1)

-def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_type="model_weight"):
+def _file_save_async_or_sync(
+    self, state_dict, path, is_sync=True, state_dict_type="model_weight", ckpt_quant_stage="O0"
+):
if is_sync:
for k in list(state_dict.keys()):
if isinstance(state_dict[k], paddle.Tensor):
state_dict[k] = state_dict.pop(k).cpu().numpy()

+state_dict = quant_unified_optimizer(state_dict, state_dict_type, ckpt_quant_stage)
safe_save_file(state_dict, path, metadata={"format": "np"})
else:
if state_dict_type == "model_weight":
@@ -221,6 +231,7 @@
self._lock,
state_dict_type,
self.global_rank,
+ckpt_quant_stage,
),
)
self._process_optimizer_weight.start()
@@ -246,6 +257,7 @@
lock,
state_dict_type,
global_rank,
ckpt_quant_stage="O0",
):
shm = shared_memory.SharedMemory(name=shm_name)
while True:
@@ -258,6 +270,9 @@
path = shared_save_path[:].decode("utf-8").rstrip("\x00")
logger.info(f"Start to async save {path}")
state_dict = _read_state_dict_from_shm(meta_dict, shm) # numpy array
+state_dict = quant_unified_optimizer(
+    state_dict, state_dict_type, ckpt_quant_stage
+)  # ckpt quantization
safe_save_file(state_dict, path, {"format": "np"})
del state_dict
saved_signal_path = os.path.join(os.path.dirname(path), f".{state_dict_type}.done.{global_rank}")
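
The hunk above is the async save path: the trainer process copies tensors into shared memory (_traverse_copy_to_shm), and the background save process attaches to the segment, rebuilds the state dict (_read_state_dict_from_shm), now quantizes it, and writes the safetensors file. A self-contained sketch of that shared-memory round trip, NumPy only and with the synchronization flags omitted:

import numpy as np
from multiprocessing import shared_memory

# Producer side: copy an array into a named shared-memory segment
# (a stand-in for _traverse_copy_to_shm in the diff).
arr = np.arange(8, dtype=np.float32)
shm = shared_memory.SharedMemory(create=True, size=arr.nbytes)
np.ndarray(arr.shape, dtype=arr.dtype, buffer=shm.buf)[:] = arr

# Consumer side (the background save process): attach by name and rebuild
# the array (a stand-in for _read_state_dict_from_shm). This is exactly the
# point where the diff inserts quant_unified_optimizer before safe_save_file.
child = shared_memory.SharedMemory(name=shm.name)
view = np.ndarray(arr.shape, dtype=np.float32, buffer=child.buf)
assert (view == arr).all()

child.close()
shm.close()
shm.unlink()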
@@ -356,7 +371,6 @@
if self.args.should_save:
config_to_save.save_pretrained(save_directory)
paddle.device.cuda.empty_cache()

if strtobool(os.getenv("FLAG_LLM_PDC", "False")) and self.args.should_save:
world_size = paddle.distributed.get_world_size()
save_info = {
@@ -433,6 +447,7 @@
path=os.path.join(output_dir, optimizer_name),
is_sync=is_sync_save,
state_dict_type="optimizer_weight",
+ckpt_quant_stage=model.config.ckpt_quant_stage,
)
self._file_save_async_or_sync(
master_weights,
@@ -523,6 +538,7 @@
path=os.path.join(save_directory, shard_optim_file),
is_sync=is_sync_save,
state_dict_type="optimizer_weight",
+ckpt_quant_stage=model.config.ckpt_quant_stage,
)
if master_weight_state_dict is not None:
self._file_save_async_or_sync(
@@ -626,7 +642,10 @@

# save checkpoint
self._file_save_async_or_sync(
-state_dict, path=os.path.join(output_dir, weight_filename), is_sync=True, state_dict_type="model_weight"
+state_dict,
+path=os.path.join(output_dir, weight_filename),
+is_sync=True,
+state_dict_type="model_weight",
)

if isinstance(model_to_save, PrefixModelForCausalLM):
@@ -659,6 +678,11 @@
static_name, type_name = generate_base_static_name(key)
new_name = static2struct_name_mappings[static_name] + "/" + type_name
optim_state_dict[new_name] = optim_state_dict.pop(key)

+if UnifiedCheckpointOption.REMOVE_MASTER_WEIGHT.value in self.args.unified_checkpoint_config:
+    logger.info("Skip master weight saving.")
+    master_weights = None
if master_weights is not None:
for key in list(master_weights.keys()):
master_weights[static2struct_name_mappings[key]] = master_weights.pop(key)
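
The REMOVE_MASTER_WEIGHT branch above drops the fp32 master weights from the optimizer checkpoint entirely. A toy illustration of the save-side effect; the option string is from this PR, while the plain list stands in for args.unified_checkpoint_config:

# Toy illustration, not PaddleNLP API.
unified_checkpoint_config = ["async_save", "remove_master_weight"]

master_weights = {"linear_0.w_0": "fp32 master copy ..."}
if "remove_master_weight" in unified_checkpoint_config:
    # Skip saving fp32 master weights; on resume, update_master_weight_status
    # (near the end of this diff) falls back to loading model weights instead.
    master_weights = None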
@@ -700,6 +724,7 @@
path=os.path.join(output_dir, "optimizer-00001-of-00001.safetensors"),
is_sync=True,
state_dict_type="optimizer_weight",
+ckpt_quant_stage=model.config.ckpt_quant_stage,
)
if master_weights is not None:
self._file_save_async_or_sync(
@@ -800,7 +825,11 @@
tp_actions = model.get_tensor_parallel_convert_actions(model.config, loaded_keys, ignore_error=True)
# Here we use expected_keys to optimize weights loading for pipeline model. Only works for safetensors
state_dict = load_state_dict(
-shard_file, tp_actions if pre_tensor_parallel_split else None, expected_keys, device="expected"
+shard_file,
+tp_actions if pre_tensor_parallel_split else None,
+expected_keys,
+device="expected",
+ckpt_quant_stage=model.config.ckpt_quant_stage,
)

if not pre_tensor_parallel_split:
@@ -984,10 +1013,22 @@
tp_actions = mapping_optimizer_tp_actions(tp_actions, expected_keys)

# Here we use expected_keys to optimize weights loading for pipeline model. Only works for safetensors
-state_dict = load_state_dict(shard_file, tp_actions, expected_keys, device="expected")
+state_dict = load_state_dict(
+    shard_file,
+    tp_actions,
+    expected_keys,
+    device="expected",
+    ckpt_quant_stage=model.config.ckpt_quant_stage,
+)
else:
# for pipeline model, we don't need to use tp_actions
-state_dict = load_state_dict(shard_file, None, expected_keys, device="expected")
+state_dict = load_state_dict(
+    shard_file,
+    None,
+    expected_keys,
+    device="expected",
+    ckpt_quant_stage=model.config.ckpt_quant_stage,
+)

returned_state_dict.update(state_dict)
# force memory release
@@ -1000,6 +1041,7 @@
state_dict_master_weight = load_resolved_archive_file(
resolved_archive_file_mw, sharded_metadata_mw, expected_keys_mw, is_master_weights=True
)

# rename optimizer param
for key in list(state_dict_optim.keys()):
key_name = key.split("/")
@@ -1057,6 +1099,11 @@
static_name, type_name = generate_base_static_name(key)
new_name = static2struct_name_mappings[static_name] + "/" + type_name
optim_state_dict[new_name] = optim_state_dict.pop(key)

+if UnifiedCheckpointOption.REMOVE_MASTER_WEIGHT.value in args.unified_checkpoint_config:
+    logger.info("Skip master weight saving.")
+    master_weights = None

if master_weights is not None:
for key in list(master_weights.keys()):
master_weights[static2struct_name_mappings[key]] = master_weights.pop(key)
@@ -1707,7 +1754,9 @@
if len(missing_keys) > 0:
raise ValueError(f"Missing keys: {missing_keys}")

-state_dict = load_state_dict(resolved_archive_file[0], None, expected_keys)
+state_dict = load_state_dict(
+    resolved_archive_file[0], None, expected_keys, ckpt_quant_stage=model.config.ckpt_quant_stage
+)
error_msgs = _load_state_dict_into_model(model, state_dict, "")
del state_dict
gc.collect()
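
These load-side call sites now forward ckpt_quant_stage so the loader knows whether the shard was quantized at save time and must be dequantized back to full precision. The actual scheme lives in unified_checkpoint_quantization.py, which this excerpt does not show; the round trip below is a generic symmetric int8 example of the idea, not the PR's algorithm:

import numpy as np

# Generic per-tensor symmetric int8 round trip; an illustrative stand-in for
# quant_unified_optimizer (save side) and the dequantization load_state_dict
# must apply when ckpt_quant_stage != "O0".
moment = np.random.randn(4, 8).astype(np.float32)

scale = np.abs(moment).max() / 127.0                      # save side
q = np.clip(np.round(moment / scale), -127, 127).astype(np.int8)

recovered = q.astype(np.float32) * scale                  # load side
print("max abs error:", float(np.abs(recovered - moment).max()))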
@@ -1737,9 +1786,13 @@
)
expected_keys_mw = sharded_metadata_mw["all_optimizer_keys"]

-state_dict_optim = load_state_dict(resolved_archive_file[0], None, expected_keys)
+state_dict_optim = load_state_dict(
+    resolved_archive_file[0], None, expected_keys, ckpt_quant_stage=model.config.ckpt_quant_stage
+)
Reviewer comment (Contributor): Same as above, this should not be read from args; the same applies below.

if has_master_weights:
-state_dict_optim_mw = load_state_dict(resolved_archive_file_mw[0], None, expected_keys_mw)
+state_dict_optim_mw = load_state_dict(
+    resolved_archive_file_mw[0], None, expected_keys_mw, ckpt_quant_stage=model.config.ckpt_quant_stage
+)

for key in list(state_dict_optim.keys()):
key_name = key.split("/")
@@ -2008,14 +2061,27 @@
filter_tensor_list = [[] for i in range(tp_size)]

if tp_rank == 0:
+quant = False
+if model_to_save.config.ckpt_quant_stage != "O0":
+    quant = True
tensor_bytes_dict = {}
model_state_dict = get_expected_state_dict(model_to_save)
for (k, v) in state_dict.items():
-model_v = model_state_dict[k.split("/")[0]] if is_optimizer else v
-if hasattr(model_v, "is_distributed") and model_v.is_distributed:
-    tensor_bytes_dict[k] = v.numel().item() * tp_size * dtype_byte_size(v.dtype)
-else:
-    tensor_bytes_dict[k] = v.numel().item() * dtype_byte_size(v.dtype)
+weight_key = k.split("/")[0]
+model_v = model_state_dict[weight_key] if is_optimizer else v
+if not quant or not is_optimizer:
+    if hasattr(model_v, "is_distributed") and model_v.is_distributed:
+        tensor_bytes_dict[k] = v.numel().item() * tp_size * dtype_byte_size(v.dtype)
+    else:
+        tensor_bytes_dict[k] = v.numel().item() * dtype_byte_size(v.dtype)
+else:
+    if weight_key not in tensor_bytes_dict:
+        tensor_bytes_dict[weight_key] = 0
+
+    if hasattr(model_v, "is_distributed") and model_v.is_distributed:
+        tensor_bytes_dict[weight_key] += v.numel().item() * tp_size * dtype_byte_size(v.dtype)
+    else:
+        tensor_bytes_dict[weight_key] += v.numel().item() * dtype_byte_size(v.dtype)
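
With quantization enabled, the sizing pass above keys byte counts by the owning parameter (everything before the "/" in an optimizer key) rather than by each optimizer-state entry, so all of a parameter's states land in the same shard and can be quantized together. A small sketch of that accumulation, with made-up key names and sizes:

# Made-up keys and byte sizes, mirroring the grouping logic above.
entries = [
    ("linear_0.w_0/moment1_0", 4096),
    ("linear_0.w_0/moment2_0", 4096),
    ("linear_0.w_0/beta1_pow_acc_0", 4),
    ("linear_0.w_0/beta2_pow_acc_0", 4),
]

tensor_bytes_dict = {}
for k, nbytes in entries:
    weight_key = k.split("/")[0]
    tensor_bytes_dict[weight_key] = tensor_bytes_dict.get(weight_key, 0) + nbytes

print(tensor_bytes_dict)  # {'linear_0.w_0': 8200}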

filter_tensor_list = []
current_block = []
@@ -2036,7 +2102,14 @@
current_block = []
current_block_size = 0

-current_block.append(key)
+if not quant or not is_optimizer:
+    current_block.append(key)
+else:
+    current_block.append(key + "/" + MOMENT1_KEYNAME)
+    current_block.append(key + "/" + MOMENT2_KEYNAME)
+    current_block.append(key + "/" + BETA1_KEYNAME)
+    current_block.append(key + "/" + BETA2_KEYNAME)
current_block_size += weight_size
total_size += weight_size

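
When a grouped weight key is then assigned to a shard block, it is expanded back into its four Adam state entries via the KEYNAME constants imported at the top of this file's diff. Sketch below; only the constant names appear in this excerpt, so their string values here are assumptions:

# Assumed string values; only the names MOMENT1_KEYNAME etc. are shown above.
MOMENT1_KEYNAME, MOMENT2_KEYNAME = "moment1_0", "moment2_0"
BETA1_KEYNAME, BETA2_KEYNAME = "beta1_pow_acc_0", "beta2_pow_acc_0"

key = "linear_0.w_0"
current_block = [
    key + "/" + suffix
    for suffix in (MOMENT1_KEYNAME, MOMENT2_KEYNAME, BETA1_KEYNAME, BETA2_KEYNAME)
]
print(current_block)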
@@ -2307,7 +2380,10 @@
def update_master_weight_status(args, optimizer, has_master_weight, safe_serialization):
if is_need_master_weight(optimizer, is_fp16_or_bp16=(args.fp16 or args.bf16)):
if not has_master_weight:
-if UnifiedCheckpointOption.MASTER_WEIGHT_COMPATIBLE.value in args.unified_checkpoint_config:
+if (
+    UnifiedCheckpointOption.REMOVE_MASTER_WEIGHT.value in args.unified_checkpoint_config
+    or UnifiedCheckpointOption.MASTER_WEIGHT_COMPATIBLE.value in args.unified_checkpoint_config
+):
index_filename_master_weights = (
PADDLE_WEIGHTS_INDEX_NAME if not safe_serialization else SAFE_WEIGHTS_INDEX_NAME
)
@@ -2319,7 +2395,8 @@
else:
raise ValueError(
"Can't find a valid unified master weight checkpoint,"
f"add '{UnifiedCheckpointOption.MASTER_WEIGHT_COMPATIBLE.value}' into 'unified_checkpoint_config' to "
f"add '{UnifiedCheckpointOption.MASTER_WEIGHT_COMPATIBLE.value}'"
f" or '{UnifiedCheckpointOption.REMOVE_MASTER_WEIGHT.value}' into 'unified_checkpoint_config' to "
"load model checkpoint as master weight"
)
else: