[Unified Checkpoint] Checkpoint compression #9183
Conversation
Codecov Report

Additional details and impacted files

@@            Coverage Diff             @@
##           develop    #9183      +/-   ##
==========================================
  Coverage    52.98%    52.98%
==========================================
  Files          676       687      +11
  Lines       108003    109184    +1181
==========================================
+ Hits         57220     57851     +631
- Misses       50783     51333     +550
Please add a unit test.
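For reference, a minimal sketch of the kind of round-trip check such a unit test might contain (class name, shapes, and tolerance are illustrative, not the PR's actual test):

```python
import unittest

import numpy as np


class TestCheckpointQuantRoundTrip(unittest.TestCase):
    def test_int8_symmetric_round_trip(self):
        x = np.random.randn(64, 32).astype("float32")
        bnt = (1 << 7) - 1  # signed int8 bound: 127
        # channel-wise abs-max scales, guarded against all-zero channels
        scales = np.maximum(np.max(np.abs(x), axis=0), 1e-8)
        quant = np.clip(np.round(x / scales * bnt), -bnt - 1, bnt)
        recovered = quant / bnt * scales
        # symmetric 8-bit quantization error stays within one quantization step
        np.testing.assert_allclose(recovered, x, atol=float(np.max(scales)) / bnt)


if __name__ == "__main__":
    unittest.main()
```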
@@ -149,11 +165,93 @@ def __init__(self, args):
        self._shared_save_master_weight_flag = multiprocessing.Array("i", 1)
        self._shared_save_optimizer_flag = multiprocessing.Array("i", 1)

    def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_type="model_weight"):
    def quant_unified_optimizer(self, state_dict, state_dict_type, ckpt_quant_stage):
I'd suggest pulling this out into a separate file. I'm currently refactoring unified_checkpoint.py and will be splitting quite a lot of the logic out of it.
done
paddlenlp/trainer/trainer.py
Outdated
@@ -1179,6 +1179,17 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):
                self.state.epoch = epoch + (step + 1) / steps_in_epoch
                self.control = self.callback_handler.on_step_end(args, self.state, self.control)
                self._maybe_log_save_evaluate(tr_loss, model, epoch, ignore_keys_for_eval, inputs=inputs)
                if self.state.global_step != 0 and (self.state.global_step) % self.args.save_steps == 0:
What exactly does this change do?
done
paddlenlp/peft/lora/lora_model.py
Outdated
                        shard_file,
                        tp_actions if pre_tensor_parallel_split else None,
                        expected_keys,
                        ckpt_quant_stage=model.config.ckpt_quant_stage,
Why does ckpt_quant_stage need to be passed in here? With the default O0 it shouldn't need to be passed at all.
done
                        shard_file,
                        tp_actions if pre_tensor_parallel_split else None,
                        expected_keys,
                        ckpt_quant_stage=model.config.ckpt_quant_stage,
Same as above.
done
@@ -175,6 +186,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
                    self._lock,
                    state_dict_type,
                    self.global_rank,
                    ckpt_quant_stage,
If only optimizer_weight needs compression, and model_weight and master_weight do not, then this variable doesn't have to be passed in here.
done
@@ -199,6 +211,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
                    if "skip_save_model_weight" in self.args.unified_checkpoint_config
                    else state_dict_type,
                    self.global_rank,
                    ckpt_quant_stage,
Same as above.
@@ -246,6 +260,7 @@ def _save_file_async_in_process(
            lock,
            state_dict_type,
            global_rank,
            ckpt_quant_stage,
Just make it an optional parameter, e.g. ckpt_quant_stage="O0".
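A minimal sketch of the shape being suggested, assuming the quantization is a no-op outside optimizer weights (names follow the diff; the body is illustrative, not the PR's implementation):

```python
def quant_unified_optimizer(self, state_dict, state_dict_type, ckpt_quant_stage="O0"):
    """Sketch: "O0" (the default) and non-optimizer state dicts pass through untouched,
    so call sites that never compress can simply omit ckpt_quant_stage."""
    if ckpt_quant_stage == "O0" or state_dict_type != "optimizer_weight":
        return state_dict
    # ... quantize the optimizer tensors (moment1/moment2, etc.) here ...
    return state_dict
```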
@@ -326,6 +344,7 @@ def save_unified_checkpoint(self, model, optimizer, output_dir):
                path=os.path.join(save_directory, shard_file),
                is_sync=is_sync_save,
                state_dict_type="model_weight",
                ckpt_quant_stage=model_to_save.config.ckpt_quant_stage,
Remove this.
            )
            self._file_save_async_or_sync(
                master_weights,
                path=os.path.join(output_dir, master_weights_name),
                is_sync=is_sync_save,
                state_dict_type="master_weight",
                ckpt_quant_stage=model.config.ckpt_quant_stage,
Remove this.
            )
            if master_weight_state_dict is not None:
                self._file_save_async_or_sync(
                    master_weight_state_dict,
                    path=os.path.join(save_directory, shard_master_weight_file),
                    is_sync=is_sync_save,
                    state_dict_type="master_weight",
                    ckpt_quant_stage=model.config.ckpt_quant_stage,
Remove this.
…nto ckpt-compress Conflicts: paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py
            returned_state_dict.update(state_dict)
            # force memory release
            del state_dict
            gc.collect()
        return returned_state_dict
    state_dict_optim = load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected_keys)
    index = {}
This line can be removed; it is redundant.
@@ -215,6 +215,11 @@ def save_non_merge_optimizer(self, model, optim_state_dict, master_weights, outp
            static_name, type_name = generate_base_static_name(key)
            new_name = static2struct_name_mappings[static_name] + "/" + type_name
            optim_state_dict[new_name] = optim_state_dict.pop(key)

        if UnifiedCheckpointOption.REMOVE_MASTER_WEIGHT.value in self.args.unified_checkpoint_config:
The REMOVE_MASTER_WEIGHT check shouldn't live inside this function; the caller should ensure that the master_weights passed into save_non_merge_optimizer is already None in that case.
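A minimal sketch of the call-site control being asked for (only the option check comes from the diff; the surrounding call is illustrative):

```python
# Decide at the call site, so save_non_merge_optimizer never has to know about the option.
if UnifiedCheckpointOption.REMOVE_MASTER_WEIGHT.value in self.args.unified_checkpoint_config:
    master_weights = None  # strip master weights before handing them to the save path
self.save_non_merge_optimizer(model, optim_state_dict, master_weights, output_dir)
```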
@@ -320,6 +332,100 @@ def get_parameter_dtype(parameter: nn.Layer) -> paddle.dtype:
    return last_dtype


def dequant_unified_optimizer(self, state_dict, ckpt_quant_stage, scale_dict):
The self argument here is superfluous. Has this been tested?
…nto ckpt-compress Conflicts: paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py
):
    """
    Reads a PaddlePaddle checkpoint file, returning properly formatted errors if they arise.
    """
    quant = False
    if ckpt_quant_stage != "O0":
        quant = "optimizer" in checkpoint_file
This is a bit hacky.
@@ -320,6 +332,100 @@ def get_parameter_dtype(parameter: nn.Layer) -> paddle.dtype:
    return last_dtype


def dequant_unified_optimizer(state_dict, ckpt_quant_stage, scale_dict):
Move this under the ..quantization directory?
        with safe_open(checkpoint_file, framework="np") as f:
            for key in keys:
                if fliter_dict_keys is not None and key not in fliter_dict_keys:
                # non merge ckpt loading dont have filter key.
                if key.endswith(SYMMETRY_QUANT_SCALE) or (fliter_dict_keys is not None and key not in fliter_dict_keys):
-                if key.endswith(SYMMETRY_QUANT_SCALE) or (fliter_dict_keys is not None and key not in fliter_dict_keys):
+                if key.endswith(SYMMETRY_QUANT_SCALE):
+                    continue
+                if fliter_dict_keys is not None and key not in fliter_dict_keys:
+                    continue
paddlenlp/utils/env.py
Outdated
MOMENT2_KEYNAME = "moment2_0"
BETA1_KEYNAME = "beta1_pow_acc_0"
BETA2_KEYNAME = "beta2_pow_acc_0"
SYMMETRY_QUANT_SCALE = "_codebook"
+1. It would be better to include some special characters in this suffix; otherwise name collisions are easy to hit.
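For illustration, one hypothetical way to make the suffix collision-resistant (the actual name and value chosen in the PR may differ):

```python
# A plain "_codebook" suffix could collide with a real weight name; characters that
# never appear in parameter/optimizer keys make the scale keys unmistakable.
SYMMETRY_QUANT_SCALE = "@scales"  # hypothetical value

# e.g. the scale tensor for key "linear_0.w_0/moment1_0" would then be stored
# under "linear_0.w_0/moment1_0@scales"
```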
"- async_save: enable asynchronous saving checkpoints to disk\n" | ||
"- enable_all_options: enable all optimization configurations\n" | ||
) | ||
}, | ||
) | ||
ckpt_quant_stage: str = field( |
Consider whether this should live in unified_checkpoint_config, since it is used together with unified checkpoint (UC).
LGTM
@@ -0,0 +1,303 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2020 -> 2024
done
    abs_max_values = np.where(
        abs_max_values == np.array(0, dtype=inputs.dtype), np.array(1e-8, dtype=inputs.dtype), abs_max_values
    )
    return abs_max_values
Hard-coding 1e-8 here seems to ignore the training dtype; bf16, float16, and float32 have quite different representable ranges.
In group-wise quantization a group can be all zeros, which would cause a divide-by-zero during quantization; the 1e-8 here is just a small bias to guard against that.
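A minimal NumPy sketch of the guard being discussed (the function name mirrors the diff; the epsilon and axis handling are illustrative, not the PR's exact code):

```python
import numpy as np


def cal_abs_max_channel(inputs, quant_axis=1):
    """Channel-wise abs-max along quant_axis, with a small bias so that an
    all-zero channel/group never produces a zero scale (divide-by-zero guard)."""
    reduce_axis = tuple(i for i in range(inputs.ndim) if i != quant_axis)
    abs_max_values = np.max(np.abs(inputs), axis=reduce_axis)
    # 1e-8 is the bias under discussion; note it can underflow to zero if
    # inputs.dtype is float16, which is the dtype concern raised above
    abs_max_values = np.where(
        abs_max_values == 0, np.array(1e-8, dtype=inputs.dtype), abs_max_values
    )
    return abs_max_values
```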
import numpy as np
import paddle
Every important function should have a docstring, including its args. For any quantization algorithm you reference, add the arXiv link.
done
# channel-wise abs max calculation
def cal_abs_max_channel(inputs, quant_axis=1):
Why does quant_axis default to 1 here?
Add a comment explaining the magic number.
        qdq_x = (
            quant_x
            / bnt
            * scales[rank * scales.shape[0] // world_size : (rank + 1) * scales.shape[0] // world_size]
This variable name is odd: world_size usually means the total number of training cards, but here it refers to the size of the tensor-parallel communication group. Please watch the naming.
A related question: it looks like all parameters get quantized, but Norm parameters are not split under tensor parallelism. Can they still be quantized this way?
Please add comments wherever possible; otherwise the code is hard to read.
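For readers following this exchange, a rough sketch of the symmetric quantize/dequantize round trip under discussion, with the tensor-parallel scale slicing simplified away (this is not the PR's implementation; `bnt` is the signed bound for the bit width, e.g. 127 for 8-bit):

```python
import numpy as np


def qdq_weight_sketch(x, quant_bit=8):
    """Symmetric channel-wise quantization: returns the integer tensor and the
    per-channel scales needed to dequantize it later."""
    bnt = (1 << (quant_bit - 1)) - 1  # 127 for int8, 7 for int4
    scales = np.maximum(np.max(np.abs(x), axis=0), 1e-8)
    quant_x = np.clip(np.round(x / scales * bnt), -bnt - 1, bnt).astype("int8")
    return quant_x, scales


def dequant_sketch(quant_x, scales, quant_bit=8):
    bnt = (1 << (quant_bit - 1)) - 1
    # under tensor parallelism the PR slices `scales` to the current rank's shard
    # (the expression the comment above refers to); omitted here for clarity
    return quant_x.astype("float32") / bnt * scales
```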
    if len(scales.shape) == 0 or quant_x.shape[-1] == scales.shape[-1]:
        qdq_x = (quant_x / bnt * scales) + mins
    else:
        qdq_x = (
Some of the same issues as in qdq_weight apply here.
    int4_high = np.where(int4_high > 8, int4_high - 16, int4_high)

    high_tensor = paddle.Tensor(int4_high, zero_copy=True)
    low_tensor = paddle.Tensor(int4_low, zero_copy=True)
Are these Tensors placed on GPU or CPU?
CPU -> GPU; zero_copy has been removed.
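For context, a small NumPy sketch of the two's-complement int4 nibble packing/unpacking involved here (helper names are illustrative, not the PR's API):

```python
import numpy as np


def pack_int4(high, low):
    """Pack two signed int4 arrays (values in [-8, 7]) into one uint8 array,
    high nibble first."""
    h = (high.astype("int8") & 0x0F).astype("uint8")
    l = (low.astype("int8") & 0x0F).astype("uint8")
    return (h << 4) | l


def unpack_int4(packed):
    """Recover the two signed int4 halves; nibble values at or above 8 wrap back
    into the negative range (the diff above does the same wrap with np.where)."""
    high = (packed >> 4).astype("int8")
    low = (packed & 0x0F).astype("int8")
    high = np.where(high >= 8, high - 16, high)
    low = np.where(low >= 8, low - 16, low)
    return high, low
```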
    m1_quant, codebook = qdq_weight(state_dict[m1_key], quant_bit=8)
    quant_weight, mins, maxs = asymmetry_qdq_weight(ratio, quant_bit=8)
    state_dict[m1_key] = m1_quant
    codebook_dict[m1_key + SYMMETRY_QUANT_SCALE] = codebook
What is the origin of the name codebook here?
These have been renamed to scales throughout.
    dist.all_reduce(quant_bits)

    model_numel = all_bits / 4
    all_bits = model_numel * 7.0
Please add comments for the magic numbers.
This calculation was somewhat redundant; it has been removed.
…nto ckpt-compress
LGTM
LGTM
* checkpoint compression init * add ckpt quant argument * add ckpt quant ci * fix ci * fix lint * remove stage O2, change O3 --> O2 * support async save * file adjustment * magic string remove * ci fix * ci fix, code refinement * function extraction * fix ci * code refinement * fix ci * fix ci * support non merge tp ckpt quantization * fix ci * update * fix bug * code refactor * fix lint * fix ci * del old uc.py * fix lint * add mgpu ci * fix ci * multi thread loading * fix lint * fix bug * refactor code * add comment * fix lint * add comment * add comment * fix bug * fix bugs when ckpt no quant and no master weight * remove uni-test Conflicts: paddlenlp/transformers/model_utils.py
PR types
PR changes
Description
Implements checkpoint compression.
Newly added parameters