
[Unified Checkpoint] Checkpoint compression #9183

Merged
merged 44 commits into PaddlePaddle:develop on Nov 25, 2024

Conversation

@wtmlon (Collaborator) commented Sep 23, 2024

PR types

PR changes

Description

Implements checkpoint compression.

New arguments:

  • --ckpt_quant_stage "O0"/"O1"/"O2"
    • O0: no compression
    • O1: channel-wise int8 compression
    • O2: group-wise int4 compression

  • --unified_checkpoint_config "remove_master_weight"
    • With AMP O2, enabling this flag skips saving a separate copy of the master weights.
    • If the flag is enabled while loading a checkpoint that does contain master weights, those master weights are still read and loaded as usual.

A minimal usage sketch follows below.
    
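A minimal sketch of how the new options might be wired through `TrainingArguments` (illustrative only: the surrounding training setup and the exact field values are assumptions; only `ckpt_quant_stage` and `unified_checkpoint_config` come from this PR):

```python
# Illustrative sketch, not a verified configuration.
from paddlenlp.trainer import TrainingArguments

training_args = TrainingArguments(
    output_dir="./checkpoints",
    fp16_opt_level="O2",             # AMP O2, so master weights would normally be saved
    unified_checkpoint=True,
    # remove_master_weight: skip saving a separate master-weight copy
    unified_checkpoint_config="async_save remove_master_weight",
    ckpt_quant_stage="O1",           # O0: none, O1: channel-wise int8, O2: group-wise int4
    save_steps=500,
)
```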

@CLAassistant commented Sep 23, 2024

CLA assistant check
All committers have signed the CLA.

codecov bot commented Sep 23, 2024

Codecov Report

Attention: Patch coverage is 11.94969% with 280 lines in your changes missing coverage. Please review.

Project coverage is 52.98%. Comparing base (2ecf7ef) to head (b2bcf16).
Report is 28 commits behind head on develop.

Files with missing lines Patch % Lines
...enlp/quantization/checkpoint_quantization_utils.py 9.17% 99 Missing ⚠️
...lp/quantization/unified_checkpoint_quantization.py 7.29% 89 Missing ⚠️
...p/trainer/unified_checkpoint/unified_checkpoint.py 2.85% 34 Missing ⚠️
paddlenlp/trainer/unified_checkpoint/utils.py 7.69% 24 Missing ⚠️
paddlenlp/trainer/unified_checkpoint/load_local.py 6.66% 14 Missing ⚠️
paddlenlp/transformers/model_utils.py 40.00% 12 Missing ⚠️
paddlenlp/peft/lora/lora_model.py 0.00% 3 Missing ⚠️
paddlenlp/trainer/trainer_utils.py 33.33% 2 Missing ⚠️
...dlenlp/trainer/unified_checkpoint/async_handler.py 33.33% 2 Missing ⚠️
paddlenlp/trainer/training_args.py 0.00% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           develop    #9183     +/-   ##
==========================================
  Coverage    52.98%   52.98%             
==========================================
  Files          676      687     +11     
  Lines       108003   109184   +1181     
==========================================
+ Hits         57220    57851    +631     
- Misses       50783    51333    +550     

@wtmlon wtmlon requested a review from DesmonDay September 24, 2024 09:48
@ZHUI ZHUI changed the title checkpoint compression init [Unified Checkpoint] Checkpoint compression init Sep 30, 2024
@ZHUI (Collaborator) left a comment

Please add a unit test.

@wtmlon wtmlon changed the title [Unified Checkpoint] Checkpoint compression init [Unified Checkpoint] Checkpoint compression Oct 14, 2024
@@ -149,11 +165,93 @@ def __init__(self, args):
self._shared_save_master_weight_flag = multiprocessing.Array("i", 1)
self._shared_save_optimizer_flag = multiprocessing.Array("i", 1)

def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_type="model_weight"):
def quant_unified_optimizer(self, state_dict, state_dict_type, ckpt_quant_stage):
@DesmonDay (Contributor) commented Oct 14, 2024

I suggest pulling this out into a separate file. I'm currently refactoring unified_checkpoint.py and will be splitting out quite a bit of the logic.

Collaborator Author

done

@@ -1179,6 +1179,17 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):
self.state.epoch = epoch + (step + 1) / steps_in_epoch
self.control = self.callback_handler.on_step_end(args, self.state, self.control)
self._maybe_log_save_evaluate(tr_loss, model, epoch, ignore_keys_for_eval, inputs=inputs)
if self.state.global_step != 0 and (self.state.global_step) % self.args.save_steps == 0:
Contributor

What exactly is this change here?

Collaborator Author

done

shard_file,
tp_actions if pre_tensor_parallel_split else None,
expected_keys,
ckpt_quant_stage=model.config.ckpt_quant_stage,
Contributor

Why does ckpt_quant_stage need to be passed in here? If the default is O0, it shouldn't need to be passed.

Collaborator Author

done

shard_file,
tp_actions if pre_tensor_parallel_split else None,
expected_keys,
ckpt_quant_stage=model.config.ckpt_quant_stage,
Contributor

Same as above.

Collaborator Author

done

@@ -175,6 +186,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
self._lock,
state_dict_type,
self.global_rank,
ckpt_quant_stage,
Contributor

If only optimizer_weight needs compression and the others such as model_weight and master_weight don't, then this variable doesn't need to be passed in.

Collaborator Author

done

@@ -199,6 +211,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
if "skip_save_model_weight" in self.args.unified_checkpoint_config
else state_dict_type,
self.global_rank,
ckpt_quant_stage,
Contributor

Same as above.

@@ -246,6 +260,7 @@ def _save_file_async_in_process(
lock,
state_dict_type,
global_rank,
ckpt_quant_stage,
Contributor

Just make this an optional parameter, e.g. ckpt_quant_stage="O0".

@@ -326,6 +344,7 @@ def save_unified_checkpoint(self, model, optimizer, output_dir):
path=os.path.join(save_directory, shard_file),
is_sync=is_sync_save,
state_dict_type="model_weight",
ckpt_quant_stage=model_to_save.config.ckpt_quant_stage,
Contributor

Remove this.

)
self._file_save_async_or_sync(
master_weights,
path=os.path.join(output_dir, master_weights_name),
is_sync=is_sync_save,
state_dict_type="master_weight",
ckpt_quant_stage=model.config.ckpt_quant_stage,
Contributor

Remove this.

)
if master_weight_state_dict is not None:
self._file_save_async_or_sync(
master_weight_state_dict,
path=os.path.join(save_directory, shard_master_weight_file),
is_sync=is_sync_save,
state_dict_type="master_weight",
ckpt_quant_stage=model.config.ckpt_quant_stage,
Contributor

Remove this.

…nto ckpt-compress

Conflicts:
	paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py

returned_state_dict.update(state_dict)
# force memory release
del state_dict
gc.collect()
return returned_state_dict

state_dict_optim = load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected_keys)
index = {}
Contributor

This line can be removed; it's redundant.

@@ -215,6 +215,11 @@ def save_non_merge_optimizer(self, model, optim_state_dict, master_weights, outp
static_name, type_name = generate_base_static_name(key)
new_name = static2struct_name_mappings[static_name] + "/" + type_name
optim_state_dict[new_name] = optim_state_dict.pop(key)

if UnifiedCheckpointOption.REMOVE_MASTER_WEIGHT.value in self.args.unified_checkpoint_config:
@DesmonDay (Contributor) commented Oct 30, 2024

The REMOVE_MASTER_WEIGHT check shouldn't be written inside this function; instead, make sure the master_weights passed into save_non_merge_optimizer is already None.
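A hypothetical helper sketching the suggested call-site pattern (the helper name is made up; the "remove_master_weight" string is the option added in this PR, i.e. what UnifiedCheckpointOption.REMOVE_MASTER_WEIGHT.value refers to):

```python
def maybe_drop_master_weights(master_weights, unified_checkpoint_config):
    """Decide at the call site whether master weights should be saved at all,
    so save_non_merge_optimizer itself stays free of the flag check."""
    if "remove_master_weight" in unified_checkpoint_config:
        return None  # nothing to save; the save routine then skips master weights
    return master_weights
```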

@@ -320,6 +332,100 @@ def get_parameter_dtype(parameter: nn.Layer) -> paddle.dtype:
return last_dtype


def dequant_unified_optimizer(self, state_dict, ckpt_quant_stage, scale_dict):
Contributor

The self parameter is redundant. Has this been tested?

…nto ckpt-compress

Conflicts:
	paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py
):
"""
Reads a PaddlePaddle checkpoint file, returning properly formatted errors if they arise.
"""
quant = False
if ckpt_quant_stage != "O0":
quant = "optimizer" in checkpoint_file
Collaborator

This is a bit hacky.

@@ -320,6 +332,100 @@ def get_parameter_dtype(parameter: nn.Layer) -> paddle.dtype:
return last_dtype


def dequant_unified_optimizer(state_dict, ckpt_quant_stage, scale_dict):
Collaborator

Should this be moved under the ..quantization folder?

with safe_open(checkpoint_file, framework="np") as f:
for key in keys:
if fliter_dict_keys is not None and key not in fliter_dict_keys:
# non merge ckpt loading dont have filter key.
if key.endswith(SYMMETRY_QUANT_SCALE) or (fliter_dict_keys is not None and key not in fliter_dict_keys):
Collaborator

Suggested change
if key.endswith(SYMMETRY_QUANT_SCALE) or (fliter_dict_keys is not None and key not in fliter_dict_keys):
if key.endswith(SYMMETRY_QUANT_SCALE):
continue
if (fliter_dict_keys is not None and key not in fliter_dict_keys):
continue

MOMENT2_KEYNAME = "moment2_0"
BETA1_KEYNAME = "beta1_pow_acc_0"
BETA2_KEYNAME = "beta2_pow_acc_0"
SYMMETRY_QUANT_SCALE = "_codebook"
Collaborator

+1. You'd better add some special characters to this suffix, otherwise name collisions are likely.

"- async_save: enable asynchronous saving checkpoints to disk\n"
"- enable_all_options: enable all optimization configurations\n"
)
},
)
ckpt_quant_stage: str = field(
Collaborator

Consider whether this should be configured through unified_checkpoint_config, since it is used together with Unified Checkpoint.

DesmonDay previously approved these changes Nov 15, 2024
@DesmonDay (Contributor) left a comment

LGTM

ZHUI previously approved these changes Nov 15, 2024
@@ -0,0 +1,303 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Collaborator

2020 -> 2024

Collaborator Author

done

abs_max_values = np.where(
abs_max_values == np.array(0, dtype=inputs.dtype), np.array(1e-8, dtype=inputs.dtype), abs_max_values
)
return abs_max_values
Collaborator

Does using a hard-coded 1e-8 here ignore the training dtype? bf16, float16 and float32 have quite different representable ranges.

Collaborator Author

In group-wise quantization a group can be entirely zero, which would cause a divide-by-zero during quantization; the 1e-8 here is a small bias that guards against that.
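For readers following along, a small sketch of the channel-wise abs-max scale with the zero guard discussed above (the reduction axes are an assumption; the 1e-8 guard mirrors the snippet in the diff):

```python
import numpy as np

def cal_abs_max_channel_sketch(inputs, quant_axis=1):
    # One abs-max scale per channel along quant_axis (assumed reduction layout).
    reduce_axis = tuple(i for i in range(inputs.ndim) if i != quant_axis)
    abs_max_values = np.max(np.abs(inputs), axis=reduce_axis).astype(inputs.dtype)
    # A channel/group that is entirely zero would later cause a divide-by-zero,
    # so a zero scale is replaced with a tiny bias (1e-8), as described above.
    abs_max_values = np.where(
        abs_max_values == np.array(0, dtype=inputs.dtype),
        np.array(1e-8, dtype=inputs.dtype),
        abs_max_values,
    )
    return abs_max_values
```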

import numpy as np
import paddle


Collaborator

All important functions need docstrings, and their parameters (args) should be documented as well.
For any quantization algorithm that is referenced, add the arXiv link.

Collaborator Author

done



# channel-wise abs max calculation
def cal_abs_max_channel(inputs, quant_axis=1):
Collaborator

Why does quant_axis default to 1 here?

Collaborator

Add a comment explaining the magic number.

qdq_x = (
quant_x
/ bnt
* scales[rank * scales.shape[0] // world_size : (rank + 1) * scales.shape[0] // world_size]
Collaborator

This variable name is confusing: world_size usually refers to the total number of training devices, but here it means the size of the tensor-parallel communication group. Please mind the naming.

Collaborator

A related question: it looks like quantization is applied to all parameters, but Norm parameters are not split under tensor parallelism. Can they still be quantized this way?

Collaborator

Please add comments wherever possible; otherwise the code is hard to read.

if len(scales.shape) == 0 or quant_x.shape[-1] == scales.shape[-1]:
qdq_x = (quant_x / bnt * scales) + mins
else:
qdq_x = (
Collaborator

Some of the same issues as in qdq_weight.
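To make the thread easier to follow, a simplified sketch of the asymmetric quantize/dequantize round trip that the `qdq_x = (quant_x / bnt * scales) + mins` line implements (the exact bit-range handling in the PR may differ; this is only an illustration under stated assumptions):

```python
import numpy as np

def asymmetry_qdq_sketch(x, quant_bit=8):
    # Per-channel min/max over axis 0 (assumed layout).
    mins = np.min(x, axis=0)
    ranges = np.max(x, axis=0) - mins
    ranges = np.where(ranges == 0, 1e-8, ranges)   # zero-range guard
    bnt = (1 << quant_bit) - 1                      # 255 for 8-bit
    quant_x = np.clip(np.round((x - mins) / ranges * bnt), 0, bnt)
    # Dequantize: mirrors `qdq_x = quant_x / bnt * scales + mins` from the diff,
    # with `scales` playing the role of the per-channel range here.
    qdq_x = quant_x / bnt * ranges + mins
    return qdq_x, quant_x, ranges, mins
```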

int4_high = np.where(int4_high > 8, int4_high - 16, int4_high)

high_tensor = paddle.Tensor(int4_high, zero_copy=True)
low_tensor = paddle.Tensor(int4_low, zero_copy=True)
Collaborator

Is this Tensor placed on the GPU or the CPU?

Collaborator Author

CPU -> GPU; zero_copy has been removed.
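For context, a sketch of splitting a packed int8 value back into its two signed int4 halves, matching the `int4_high - 16` sign correction shown above (the nibble layout is an assumption):

```python
import numpy as np

def unpack_int4_sketch(packed):
    packed = packed.astype(np.uint8)
    int4_high = (packed >> 4).astype(np.int8)   # upper nibble
    int4_low = (packed & 0x0F).astype(np.int8)  # lower nibble
    # Nibble values above 8 encode negatives in the signed int4 range.
    int4_high = np.where(int4_high > 8, int4_high - 16, int4_high)
    int4_low = np.where(int4_low > 8, int4_low - 16, int4_low)
    return int4_high, int4_low
```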

m1_quant, codebook = qdq_weight(state_dict[m1_key], quant_bit=8)
quant_weight, mins, maxs = asymmetry_qdq_weight(ratio, quant_bit=8)
state_dict[m1_key] = m1_quant
codebook_dict[m1_key + SYMMETRY_QUANT_SCALE] = codebook
Collaborator

What is the origin of the name codebook here?

Collaborator Author

Renamed consistently to scales.

dist.all_reduce(quant_bits)

model_numel = all_bits / 4
all_bits = model_numel * 7.0
Collaborator

Add a comment explaining the magic number.

Collaborator Author

This bit of computation was somewhat redundant; it has been removed.

@wtmlon wtmlon dismissed stale reviews from ZHUI and DesmonDay via a6b2236 November 19, 2024 09:50
wawltor previously approved these changes Nov 22, 2024
@wawltor (Collaborator) left a comment

LGTM

@wawltor (Collaborator) left a comment

LGTM

@wawltor wawltor merged commit 195fde3 into PaddlePaddle:develop Nov 25, 2024
9 of 12 checks passed
wtmlon added a commit to wtmlon/PaddleNLP that referenced this pull request Nov 25, 2024
* checkpoint compression init

* add ckpt quant argument

* add ckpt quant ci

* fix ci

* fix lint

* remove stage O2, change O3 --> O2

* support async save

* file adjustment

* magic string remove

* ci fix

* ci fix, code refinement

* function extraction

* fix ci

* code refinement

* fix ci

* fix ci

* support non merge tp ckpt quantization

* fix ci

* update

* fix bug

* code refactor

* fix lint

* fix ci

* del old uc.py

* fix lint

* add mgpu ci

* fix ci

* multi thread loading

* fix lint

* fix bug

* refactor code

* add comment

* fix lint

* add comment

* add comment

* fix bug

* fix bugs when ckpt no quant and no master weight

* remove uni-test
Conflicts:
	paddlenlp/transformers/model_utils.py
DesmonDay pushed a commit that referenced this pull request Nov 25, 2024