diff --git a/paddlenlp/peft/lora/lora_model.py b/paddlenlp/peft/lora/lora_model.py index 4f619307f5cd..630454a0efe2 100644 --- a/paddlenlp/peft/lora/lora_model.py +++ b/paddlenlp/peft/lora/lora_model.py @@ -262,7 +262,9 @@ def from_pretrained(cls, model, lora_path, **kwargs): pre_tensor_parallel_split = True tp_actions = lora_model._get_tensor_parallel_convert_actions(loaded_keys, is_split=True) state_dict = load_state_dict( - shard_file, tp_actions if pre_tensor_parallel_split else None, expected_keys + shard_file, + tp_actions if pre_tensor_parallel_split else None, + expected_keys, ) error_msgs += _load_state_dict_into_model(lora_model.model, state_dict, "") del state_dict diff --git a/paddlenlp/peft/prefix/prefix_model.py b/paddlenlp/peft/prefix/prefix_model.py index 29a34442280c..25d25a354b47 100644 --- a/paddlenlp/peft/prefix/prefix_model.py +++ b/paddlenlp/peft/prefix/prefix_model.py @@ -333,7 +333,9 @@ def from_pretrained( pre_tensor_parallel_split = True tp_actions = prefix_model._get_tensor_parallel_convert_actions(is_split=True) state_dict = load_state_dict( - shard_file, tp_actions if pre_tensor_parallel_split else None, expected_keys + shard_file, + tp_actions if pre_tensor_parallel_split else None, + expected_keys, ) error_msgs += _load_state_dict_into_model(prefix_model.prefix_encoder, state_dict, "") del state_dict diff --git a/paddlenlp/quantization/checkpoint_quantization_utils.py b/paddlenlp/quantization/checkpoint_quantization_utils.py new file mode 100644 index 000000000000..8541107427df --- /dev/null +++ b/paddlenlp/quantization/checkpoint_quantization_utils.py @@ -0,0 +1,364 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import numpy as np +import paddle + + +def cal_ratio(m, v, eps=1e-8): + """ + cal part adam update ratio. + Args: + m (`paddle.Tensor`): + moment in Adam optimizer. + v (`paddle.Tensor`): + variance in Adam optimizer. + eps (`int`): + epsilon in Adam optimizer. + """ + return 1 / (np.sqrt(v) + eps) + + +def group_wise_quant_dequant( + inputs, + mins=None, + maxs=None, + quant_bits=4, + group_size=32, + quant=True, + tp_rank=-1, + tp_degree=1, + use_pd=False, + symmetry=False, +): + """ + group-wise quantization (support symmetry, asymmetry). + Args: + inputs (`paddle.Tensor`): + The tensor to quantize. + mins (`paddle.Tensor`): + Min scales tensor in asymmetry quantization. + maxs (`paddle.Tensor`): + Max scales tensor in asymmetry quantization, or Abs max tensor in symmetry quantization. + quant_bits (`int`): + Quantization bits. + group_size (`int`): + Group size of group-wise quantization. + quant (`bool`): + True when quantization, False in dequantization. + tp_rank (`int`): + Tensor parallel rank. + tp_degree (`int`): + Tensor parallel world size. + use_pd (`bool`): + Whether to use paddle caculation. If False will use numpy. + symmetry (`bool`): + Whether to use symmetry quantization. 
+ """ + + qmax = (1 << (quant_bits)) - 1 + qmin = 0 + shape = inputs.shape + + if quant: + inputs_processed = inputs.reshape([shape[0] // group_size, group_size, shape[1]]) + if symmetry: + bnt = (1 << (quant_bits - 1)) - 1 + scales = np.max(np.abs(inputs_processed), axis=1) + new_scales = np.repeat(scales, repeats=group_size, axis=0) + quant_tensor = np.clip(np.round(inputs / new_scales * bnt), -bnt - 1, bnt) + return quant_tensor.astype("int8"), scales + + # scales: [shape[0] // group_size, shape[1]] + maxs = np.max(inputs_processed, axis=1) + mins = np.min(inputs_processed, axis=1) + scales = maxs - mins + # new_scales: [shape[0], shape[1]] + new_scales = np.repeat(scales, repeats=group_size, axis=0) + new_mins = np.repeat(mins, repeats=group_size, axis=0) + # add eps to avoid devide zero + quant_tensor = np.clip(np.round((inputs - new_mins) / (new_scales) * qmax), qmin, qmax) + quant_tensor = np.nan_to_num(quant_tensor) + return quant_tensor.astype("uint8"), mins, maxs + else: + if symmetry: + scales = mins + bnt = (1 << (quant_bits - 1)) - 1 + if use_pd: + new_scales = paddle.repeat_interleave(scales, group_size, 0) + else: + new_scales = np.repeat(scales, repeats=group_size, axis=0) + + if tp_rank == -1: + dequant_tensor = inputs.astype("float32") * new_scales / bnt + elif len(new_scales.shape) == 0 or inputs.shape[-1] == new_scales.shape[-1]: + # input tensor was row parallel in tp. + dequant_tensor = ( + inputs.astype("float32") + * new_scales[ + tp_rank * new_scales.shape[0] // tp_degree : (tp_rank + 1) * new_scales.shape[0] // tp_degree + ] + / bnt + ) + else: + # input tensor was column parallel in tp. + dequant_tensor = ( + inputs.astype("float32") + * new_scales[ + :, + tp_rank + * new_scales.shape[-1] + // tp_degree : (tp_rank + 1) + * new_scales.shape[-1] + // tp_degree, + ] + / bnt + ) + return dequant_tensor + + scales = maxs - mins + if use_pd: + new_scales = paddle.repeat_interleave(scales, group_size, 0) + new_mins = paddle.repeat_interleave(mins, group_size, 0) + else: + new_scales = np.repeat(scales, repeats=group_size, axis=0) + new_mins = np.repeat(mins, repeats=group_size, axis=0) + + if tp_rank == -1: + dequant_tensor = (inputs.astype("float32") / qmax * new_scales) + new_mins + elif len(new_scales.shape) == 0 or inputs.shape[-1] == new_scales.shape[-1]: + # input tensor was row parallel in tp. + dequant_tensor = ( + inputs.astype("float32") + / qmax + * new_scales[ + tp_rank * new_scales.shape[0] // tp_degree : (tp_rank + 1) * new_scales.shape[0] // tp_degree + ] + ) + new_mins[tp_rank * new_mins.shape[0] // tp_degree : (tp_rank + 1) * new_mins.shape[0] // tp_degree] + else: + # input tensor was column parallel in tp. + dequant_tensor = ( + inputs.astype("float32") + / qmax + * new_scales[ + :, tp_rank * new_scales.shape[-1] // tp_degree : (tp_rank + 1) * new_scales.shape[-1] // tp_degree + ] + ) + new_mins[ + :, tp_rank * new_mins.shape[-1] // tp_degree : (tp_rank + 1) * new_mins.shape[-1] // tp_degree + ] + return dequant_tensor + + +def merge_int4(x, y): + """ + merge 2 signed int4 to 1 int8 + Args: + x (`numpy.array`): + 4bits signed int x. + y (`numpy.array`): + 4bits signed int y. + """ + int4_high = x << 4 + int4_low = y & 0x0F + final = int4_high | int4_low + return final.astype("int8") + + +def split_int8(final): + """ + split an int8 to 2 int4 elems + Args: + final (`numpy.array`): + 8bits signed int. 
+ """ + int4_high = final >> 4 + int4_low = final & 0x0F + + int4_high = np.where(int4_high > 8, int4_high - 16, int4_high) + + high_tensor = paddle.Tensor(int4_high) + low_tensor = paddle.Tensor(int4_low) + + return high_tensor, low_tensor + + +def cal_abs_min_max_channel(inputs, quant_axis=1): + """ + channel-wise min max scales calculation + Args: + inputs (`numpy.array`): + input tensor for quantization. + quant_axis (`int`): + dimension where calulating inputs' abs min and max scales on. + """ + eps = 1e-8 + reduce_axis = tuple([i for i in range(len(inputs.shape)) if i != quant_axis]) + abs_max_values = np.max(inputs, axis=reduce_axis) + abs_min_values = np.min(inputs, axis=reduce_axis) + abs_max_values = np.where( + abs_max_values == np.array(0, dtype=inputs.dtype), np.array(eps, dtype=inputs.dtype), abs_max_values + ) + abs_min_values = np.where( + abs_min_values == np.array(0, dtype=inputs.dtype), np.array(eps, dtype=inputs.dtype), abs_min_values + ) + return abs_max_values, abs_min_values + + +def asymmetry_qdq_weight( + x, quant_bit=8, quant_axis=-1, mins=None, maxs=None, dequant=False, tp_rank=-1, tp_degree=1, use_pd=False +): + """ + channel-wise asymmetry quantization + Args: + x (`paddle.Tensor`): + The tensor to quantize. + quant_bits (`int`): + Quantization bits. + quant_axis (`int`): + Scales caculation axis. + mins (`paddle.Tensor`): + Min scales tensor in asymmetry quantization. + maxs (`paddle.Tensor`): + Max scales tensor in asymmetry quantization. + dequant (`bool`): + True when dequantization, False in quantization. + tp_rank (`int`): + Model parallel rank. + tp_degree (`int`): + Model parallel world size. + use_pd (`bool`): + Whether to use paddle caculation. If False will use numpy. + """ + + if mins is None: + maxs, mins = cal_abs_min_max_channel(x) + bnt = (1 << (quant_bit)) - 1 + scales = maxs - mins + if not dequant: + # quant + quant_x = np.clip(np.round((x - mins) / scales * bnt), 0, bnt) + return quant_x.astype(np.uint8), mins, maxs + else: + quant_x = x + # dequant + if not use_pd: + if len(scales.shape) == 0 or quant_x.shape[-1] == scales.shape[-1]: + # input tensor was row parallel in tp. + qdq_x = (quant_x / bnt * scales) + mins + else: + # input tensor was column parallel in tp. + qdq_x = ( + quant_x + / bnt + * scales[tp_rank * scales.shape[0] // tp_degree : (tp_rank + 1) * scales.shape[0] // tp_degree] + ) + mins[tp_rank * mins.shape[0] // tp_degree : (tp_rank + 1) * mins.shape[0] // tp_degree] + return qdq_x.astype(np.float32), scales + else: + if len(scales.shape) == 0 or quant_x.shape[-1] == scales.shape[-1]: + # input tensor was row parallel in tp. + qdq_x = (quant_x / bnt * scales.unsqueeze(0).expand(quant_x.shape)) + mins + else: + # input tensor was column parallel in tp. + qdq_x = ( + quant_x + / bnt + * scales[tp_rank * scales.shape[0] // tp_degree : (tp_rank + 1) * scales.shape[0] // tp_degree] + .unsqueeze(0) + .expand(quant_x.shape) + ) + mins[tp_rank * mins.shape[0] // tp_degree : (tp_rank + 1) * mins.shape[0] // tp_degree] + return qdq_x.astype(paddle.float32), scales + + +def cal_abs_max_channel(inputs, quant_axis=1): + """ + channel-wise abs max calculation + Args: + inputs (`numpy.array`): + input tensor for quantization. + quant_axis (`int`): + dimension where calulating inputs' abs max scales on. 
+ """ + epsilon = 1e-8 + reduce_axis = tuple([i for i in range(len(inputs.shape)) if i != quant_axis]) + abs_max_values = np.max(np.abs(inputs), axis=reduce_axis) + # maybe all elements are zero in one group, + # so set the scales from those group to an actual number + # from divide 0. + abs_max_values = np.where( + abs_max_values == np.array(0, dtype=inputs.dtype), np.array(epsilon, dtype=inputs.dtype), abs_max_values + ) + return abs_max_values + + +def qdq_weight(x, quant_bit=8, quant_axis=-1, scales=None, dequant=False, tp_rank=-1, tp_degree=1, use_pd=False): + """ + channel-wise symmetry quantization + Args: + x (`paddle.Tensor`): + The tensor to quantize. + quant_bits (`int`): + Quantization bits. + quant_axis (`int`): + Scales caculation axis. + scales (`paddle.Tensor`): + Abs max scales tensor in symmetry quantization. + dequant (`bool`): + True when dequantization, False in quantization. + tp_rank (`int`): + Model parallel rank. + tp_degree (`int`): + Model parallel world size. + use_pd (`bool`): + Whether to use paddle caculation. If False will use numpy. + """ + + if scales is None: + scales = cal_abs_max_channel(x) + bnt = (1 << (quant_bit - 1)) - 1 + if not dequant: + # quant + quant_x = np.clip(np.round(x / scales * bnt), -bnt - 1, bnt) + return quant_x.astype(np.int8), scales + else: + quant_x = x + # dequant + if not use_pd: + if len(scales.shape) == 0 or quant_x.shape[-1] == scales.shape[-1]: + # input tensor was row parallel in tp. + qdq_x = quant_x / bnt * scales + else: + # input tensor was column parallel in tp. + qdq_x = ( + quant_x + / bnt + * scales[tp_rank * scales.shape[0] // tp_degree : (tp_rank + 1) * scales.shape[0] // tp_degree] + ) + # fp32 , int8, int, fp32 or fp64 + return qdq_x.astype(np.float32), scales + else: + if len(scales.shape) == 0 or quant_x.shape[-1] == scales.shape[-1]: + # input tensor was row parallel in tp. + qdq_x = quant_x / bnt * scales.unsqueeze(0).expand(quant_x.shape) + else: + # input tensor was column parallel in tp. + qdq_x = ( + quant_x + / bnt + * scales[tp_rank * scales.shape[0] // tp_degree : (tp_rank + 1) * scales.shape[0] // tp_degree] + .unsqueeze(0) + .expand(quant_x.shape) + ) + # fp32 , int8, int, fp32 or fp64 + return qdq_x.astype(paddle.float32), scales diff --git a/paddlenlp/quantization/unified_checkpoint_quantization.py b/paddlenlp/quantization/unified_checkpoint_quantization.py new file mode 100644 index 000000000000..1f1c3ad0c8a1 --- /dev/null +++ b/paddlenlp/quantization/unified_checkpoint_quantization.py @@ -0,0 +1,209 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import paddle +from paddle.distributed import fleet + +from paddlenlp.quantization.checkpoint_quantization_utils import ( + asymmetry_qdq_weight, + cal_ratio, + group_wise_quant_dequant, + merge_int4, + qdq_weight, + split_int8, +) +from paddlenlp.utils.env import ( + ASYMMETRY_QUANT_SCALE_MAX, + ASYMMETRY_QUANT_SCALE_MIN, + MOMENT1_KEYNAME, + MOMENT2_KEYNAME, + SYMMETRY_QUANT_SCALE, +) +from paddlenlp.utils.log import logger + + +def dequant_unified_optimizer(state_dict, ckpt_quant_stage, scale_dict): + """ + dequantize unified optimizer state dict. + Args: + state_dict (`dict`): + unified checkpoint optimizer state dict. + ckpt_quant_stage (`str`): + checkpoint quantization stage, chosen in ["O0", "O1", "O2"]. + scale_dict (`int`): + compression checkpoint scale dict. + """ + tp_rank, tp_degree = -1, 1 + if paddle.distributed.get_world_size() > 1: + hcg = fleet.get_hybrid_communicate_group() + tp_group = hcg.get_model_parallel_group() + tp_rank, tp_degree = tp_group.rank, tp_group.nranks + + if ckpt_quant_stage == "O1": + # set eps + eps = 1e-8 + for quant_key in state_dict.keys(): + is_moment1 = MOMENT1_KEYNAME in quant_key + is_moment2 = MOMENT2_KEYNAME in quant_key + if is_moment1: + # dequant m1 + scale_key = quant_key + SYMMETRY_QUANT_SCALE + weight = state_dict[quant_key] + scales = scale_dict[scale_key] + weight, _ = qdq_weight( + weight, + scales=scales, + quant_bit=8, + dequant=True, + tp_rank=tp_rank, + tp_degree=tp_degree, + use_pd=True, + ) + state_dict[quant_key] = weight + elif is_moment2: + # dequant ratio + weight = state_dict[quant_key] + min_scale_key = quant_key + ASYMMETRY_QUANT_SCALE_MIN + max_scale_key = quant_key + ASYMMETRY_QUANT_SCALE_MAX + mins, maxs = scale_dict[min_scale_key], scale_dict[max_scale_key] + weight, _ = asymmetry_qdq_weight( + weight, + mins=mins, + maxs=maxs, + quant_bit=8, + dequant=True, + tp_rank=tp_rank, + tp_degree=tp_degree, + use_pd=True, + ) + # cal m2 + weight = paddle.square(1.0 / weight - eps) + state_dict[quant_key] = weight + elif ckpt_quant_stage == "O2": + # set eps + eps = 1e-8 + m1_state_dict = {} + for quant_key in state_dict.keys(): + # not all optimizer weights in O2 stage were quantized to int8, + # the norm-like weights were still remain in float32. 
+ if state_dict[quant_key].dtype != paddle.int8: + logger.info(f"{quant_key} skip.") + continue + # split int8 + weight = state_dict[quant_key] + m1_quant, ratio_quant = split_int8(weight.numpy()) + # dequant ratio + ratio_min_scale_key = quant_key + ASYMMETRY_QUANT_SCALE_MIN + ratio_max_scale_key = quant_key + ASYMMETRY_QUANT_SCALE_MAX + m1_scale_key = quant_key[: -len(MOMENT2_KEYNAME)] + MOMENT1_KEYNAME + SYMMETRY_QUANT_SCALE + m1_scales = scale_dict[m1_scale_key] + ratio_mins, ratio_maxs = scale_dict[ratio_min_scale_key], scale_dict[ratio_max_scale_key] + m1_weight = group_wise_quant_dequant( + m1_quant, + mins=m1_scales, + maxs=None, + quant_bits=4, + quant=False, + tp_rank=tp_rank, + tp_degree=tp_degree, + use_pd=True, + symmetry=True, + ) + ratio_weight = group_wise_quant_dequant( + ratio_quant, + mins=ratio_mins, + maxs=ratio_maxs, + quant_bits=4, + quant=False, + tp_rank=tp_rank, + tp_degree=tp_degree, + use_pd=True, + ) + + ratio_weight = paddle.square(1.0 / ratio_weight - eps) + state_dict[quant_key] = ratio_weight + m1_state_dict[quant_key[: -len(MOMENT2_KEYNAME)] + MOMENT1_KEYNAME] = m1_weight + state_dict.update(m1_state_dict) + + return state_dict + + +def quant_unified_optimizer(state_dict, state_dict_type, ckpt_quant_stage, async_save=False): + """ + quantize unified optimizer state dict. + Args: + state_dict (`dict`): + unified checkpoint optimizer state dict. + state_dict_type (`str`): + state_dict type, chosen in ["model_weight", "master_weight", "optimizer_weight"]. + ckpt_quant_stage (`str`): + checkpoint quantization stage, chosen in ["O0", "O1", "O2"]. + async_save (`bool`): + whether use async_save. + """ + quant = False + if ckpt_quant_stage != "O0": + quant = True + del_key = [] + if quant and state_dict_type == "optimizer_weight": + scales_dict = {} + opt_keys = state_dict.keys() + for k in opt_keys: + momentum1 = k.endswith(MOMENT1_KEYNAME) + momentum2 = k.endswith(MOMENT2_KEYNAME) + + quant_weight = None + + if ckpt_quant_stage == "O1": + # m1: wint8, 1/(sqrt(m2)+eps): wint8 + if momentum2: + # m1: m1_quant_weight, m2: ratio + m1_key = k.split("/")[0] + "/" + MOMENT1_KEYNAME + ratio = cal_ratio(state_dict[m1_key], state_dict[k]) + m1_quant, scales = qdq_weight(state_dict[m1_key], quant_bit=8) + quant_weight, mins, maxs = asymmetry_qdq_weight(ratio, quant_bit=8) + state_dict[m1_key] = m1_quant + scales_dict[m1_key + SYMMETRY_QUANT_SCALE] = scales + scales_dict[k + ASYMMETRY_QUANT_SCALE_MIN] = mins + scales_dict[k + ASYMMETRY_QUANT_SCALE_MAX] = maxs + elif not momentum1: + quant_weight = state_dict[k] + elif ckpt_quant_stage == "O2": + # m1: bw-wint4, 1/(sqrt(m2)+eps): bw-wint4 + if momentum2: + # skip norm-like parameters + if len(state_dict[k].shape) < 2: + continue + # m1: m1_quant_weight, m2: ratio + m1_key = k.split("/")[0] + "/" + MOMENT1_KEYNAME + ratio = cal_ratio(state_dict[m1_key], state_dict[k]) + m1_quant, m1_scales = group_wise_quant_dequant(state_dict[m1_key], quant_bits=4, symmetry=True) + quant_weight, r_mins, r_maxs = group_wise_quant_dequant(ratio, quant_bits=4) + quant_weight = merge_int4(m1_quant, quant_weight) + scales_dict[m1_key + SYMMETRY_QUANT_SCALE] = m1_scales + scales_dict[k + ASYMMETRY_QUANT_SCALE_MIN] = r_mins + scales_dict[k + ASYMMETRY_QUANT_SCALE_MAX] = r_maxs + del_key.append(m1_key) + elif not momentum1: + quant_weight = state_dict[k] + + if quant_weight is not None: + state_dict[k] = quant_weight + + for k in del_key: + state_dict.pop(k, None) + + state_dict.update(scales_dict) + + return state_dict diff --git 
a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py index b45542ccc38f..45a2967ccf92 100644 --- a/paddlenlp/trainer/trainer.py +++ b/paddlenlp/trainer/trainer.py @@ -2702,6 +2702,7 @@ def _save( "world_size": world_size, "ignore_save_lr_and_optim": self.args.ignore_save_lr_and_optim, "skip_save_model_weight": "skip_save_model_weight" in self.args.unified_checkpoint_config, + "remove_master_weight": "remove_master_weight" in self.args.unified_checkpoint_config, } if os.path.exists( os.path.join(self.args.output_signal_dir, "async_save_info.json") diff --git a/paddlenlp/trainer/trainer_utils.py b/paddlenlp/trainer/trainer_utils.py index 30488e960f14..0fc54d52f74d 100644 --- a/paddlenlp/trainer/trainer_utils.py +++ b/paddlenlp/trainer/trainer_utils.py @@ -240,13 +240,15 @@ class TrainOutput(NamedTuple): _re_checkpoint = re.compile(r"^" + PREFIX_CHECKPOINT_DIR + r"\-(\d+)$") -def _check_checkpoint_files(folder_path, world_size, ignore_save_lr_and_optim, skip_save_model_weight): +def _check_checkpoint_files( + folder_path, world_size, ignore_save_lr_and_optim, skip_save_model_weight, remove_master_weight +): files = os.listdir(folder_path) model_weight_files = [f for f in files if f.startswith(".model_weight")] a = len(model_weight_files) == world_size if not ignore_save_lr_and_optim: b = True - if not skip_save_model_weight: + if not skip_save_model_weight or not remove_master_weight: master_weight_file = [f for f in files if f.startswith(".master_weight")] b = len(master_weight_file) == world_size optimizer_file = [f for f in files if f.startswith(".optimizer_weight")] @@ -282,8 +284,13 @@ def get_last_checkpoint(folder, signal_folder=None, uc_async_save=False): pre_world_size = saving_info.get("world_size", 1) ignore_save_lr_and_optim = saving_info.get("ignore_save_lr_and_optim", False) skip_save_model_weight = saving_info.get("skip_save_model_weight", False) + remove_master_weight = saving_info.get("remove_master_weight", False) if _check_checkpoint_files( - current_signal_path, pre_world_size, ignore_save_lr_and_optim, skip_save_model_weight + current_signal_path, + pre_world_size, + ignore_save_lr_and_optim, + skip_save_model_weight, + remove_master_weight, ): return current_path return diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py index 25cf62309983..2f5f337e5994 100644 --- a/paddlenlp/trainer/training_args.py +++ b/paddlenlp/trainer/training_args.py @@ -858,11 +858,16 @@ class TrainingArguments: "- skip_save_model_weight: do not save model weights when the masters weight exist\n" "- master_weight_compatible: 1. if the master weights exist, only load when needed\n" " 2. 
if master weights does not exist, convert model weights to master weights when needed\n" + "- remove_master_weight: same with `master_weight_compatible`, use in checkpoint quantization.\n" "- async_save: enable asynchronous saving checkpoints to disk\n" "- enable_all_options: enable all optimization configurations\n" ) }, ) + ckpt_quant_stage: str = field( + default="O0", + metadata={"help": "checkpoint quantization stage."}, + ) ignore_load_lr_and_optim: Optional[bool] = field( default=False, metadata={"help": "whether to ignore load optimizer and scheduler."}, @@ -1660,6 +1665,7 @@ def is_segment_parallel_supported(): if x not in [ "skip_save_model_weight", "master_weight_compatible", + "remove_master_weight", "async_save", "enable_all_options", "ignore_merge_optimizer", diff --git a/paddlenlp/trainer/unified_checkpoint/async_handler.py b/paddlenlp/trainer/unified_checkpoint/async_handler.py index 942ea41508bf..ffe098808c2f 100644 --- a/paddlenlp/trainer/unified_checkpoint/async_handler.py +++ b/paddlenlp/trainer/unified_checkpoint/async_handler.py @@ -27,6 +27,10 @@ if is_safetensors_available(): from safetensors.numpy import save_file as safe_save_file +from paddlenlp.quantization.unified_checkpoint_quantization import ( + quant_unified_optimizer, +) + from .shared_memory_utils import ( _read_state_dict_from_shm, _traverse_copy_to_shm, @@ -69,12 +73,14 @@ def __init__(self, args): self._shared_save_optimizer_flag = multiprocessing.Array("i", 1) def _file_save_async_or_sync( - self, state_dict, path, signal_path=None, is_sync=True, state_dict_type="model_weight" + self, state_dict, path, signal_path=None, is_sync=True, state_dict_type="model_weight", ckpt_quant_stage="O0" ): if is_sync: for k in list(state_dict.keys()): if isinstance(state_dict[k], paddle.Tensor): state_dict[k] = state_dict.pop(k).cpu().numpy() + + state_dict = quant_unified_optimizer(state_dict, state_dict_type, ckpt_quant_stage) safe_save_file(state_dict, path, metadata={"format": "np"}) else: if len(state_dict.keys()) == 0: @@ -155,6 +161,7 @@ def _file_save_async_or_sync( self._lock, state_dict_type, self.global_rank, + ckpt_quant_stage, ), ) self._process_optimizer_weight.start() @@ -185,6 +192,7 @@ def _save_file_async_in_process( lock, state_dict_type, global_rank, + ckpt_quant_stage="O0", ): shm = shared_memory.SharedMemory(name=shm_name) while True: @@ -198,6 +206,9 @@ def _save_file_async_in_process( signal_path = shared_save_signal_path[:].decode("utf-8").rstrip("\x00") logger.info(f"Start to async save {path}") state_dict = _read_state_dict_from_shm(meta_dict, shm) # numpy array + state_dict = quant_unified_optimizer( + state_dict, state_dict_type, ckpt_quant_stage, async_save=True + ) # ckpt quantization safe_save_file(state_dict, path, {"format": "np"}) del state_dict saved_signal_path = os.path.join(signal_path, f".{state_dict_type}.done.{global_rank}") diff --git a/paddlenlp/trainer/unified_checkpoint/load_local.py b/paddlenlp/trainer/unified_checkpoint/load_local.py index 459eff7185d1..d1565c7dd933 100644 --- a/paddlenlp/trainer/unified_checkpoint/load_local.py +++ b/paddlenlp/trainer/unified_checkpoint/load_local.py @@ -14,6 +14,7 @@ """Unfied checkpoint locally loading functions.""" import gc +import json import os import paddle @@ -183,6 +184,13 @@ def load_unified_optimizer_locally(args, model, optimizer, resume_from_checkpoin if len(resolved_archive_file) > 1: resolved_archive_file = tqdm(resolved_archive_file, desc="Loading optimizer shards") + with open(os.path.join(resume_from_checkpoint, 
index_filename), "r") as f: + index = json.loads(f.read()) + + ckpt_quant_stage = "O0" + if "ckpt_quant_stage" in index: + ckpt_quant_stage = index["ckpt_quant_stage"] + # update has_master_weights and index_filename_master_weights # 1. if the master weight exists, only has_master_weights is set True and loaded when needed # 2. if master weight does not exist, convert model weight to master weight when needed @@ -204,7 +212,9 @@ def load_unified_optimizer_locally(args, model, optimizer, resume_from_checkpoin if len(resolved_archive_file_mw) > 1: resolved_archive_file_mw = tqdm(resolved_archive_file_mw, desc="Loading master weights shards") - def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected_keys, is_master_weights=False): + def load_resolved_archive_file( + resolved_archive_file, sharded_metadata, expected_keys, is_master_weights=False, ckpt_quant_stage="O0" + ): returned_state_dict = {} # load optimizer for shard_file in resolved_archive_file: @@ -227,10 +237,22 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected tp_actions = mapping_optimizer_tp_actions(tp_actions, expected_keys) # Here we use expected_keys to optimize weights loading for pipeline model. Only works for safetensors - state_dict = load_state_dict(shard_file, tp_actions, expected_keys, device="expected") + state_dict = load_state_dict( + shard_file, + tp_actions, + expected_keys, + device="expected", + ckpt_quant_stage=ckpt_quant_stage, + ) else: # for pipeline model, we don't need to use tp_actions - state_dict = load_state_dict(shard_file, None, expected_keys, device="expected") + state_dict = load_state_dict( + shard_file, + None, + expected_keys, + device="expected", + ckpt_quant_stage=ckpt_quant_stage, + ) returned_state_dict.update(state_dict) # force memory release @@ -238,7 +260,9 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected gc.collect() return returned_state_dict - state_dict_optim = load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected_keys) + state_dict_optim = load_resolved_archive_file( + resolved_archive_file, sharded_metadata, expected_keys, ckpt_quant_stage=ckpt_quant_stage + ) if has_master_weights: state_dict_master_weight = load_resolved_archive_file( resolved_archive_file_mw, sharded_metadata_mw, expected_keys_mw, is_master_weights=True @@ -246,9 +270,10 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected # rename optimizer param for key in list(state_dict_optim.keys()): key_name = key.split("/") - static_name = struct2static_name_mappings[key_name[0]] + model_weight_key = key_name[0] + static_name = struct2static_name_mappings[model_weight_key] if has_master_weights: - if model_state_dict[key_name[0]].dtype != paddle.float32: + if model_state_dict[model_weight_key].dtype != paddle.float32: key_name = "_".join([static_name, FP32_MASTER, key_name[1]]) else: key_name = "_".join([static_name, key_name[1]]) @@ -257,6 +282,12 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected returned_optim_state_dict[key_name] = state_dict_optim.pop(key) returned_optim_state_dict[key_name].name = key_name + # master weight cast (only in remove_master_weight) + if has_master_weights and state_dict_master_weight[model_weight_key].dtype != paddle.float32: + state_dict_master_weight[model_weight_key] = paddle.cast( + state_dict_master_weight[model_weight_key], dtype=paddle.float32 + ) + if has_master_weights: for key in 
list(state_dict_master_weight.keys()): static_name = struct2static_name_mappings[key] diff --git a/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py b/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py index 0190529a84e3..72543e038e6a 100644 --- a/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py +++ b/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py @@ -164,6 +164,7 @@ def save_unified_checkpoint(self, model, optimizer, output_dir, signal_dir=None) "world_size": world_size, "ignore_save_lr_and_optim": self.args.ignore_save_lr_and_optim, "skip_save_model_weight": "skip_save_model_weight" in self.args.unified_checkpoint_config, + "remove_master_weight": "remove_master_weight" in self.args.unified_checkpoint_config, } paddle.save(save_info, os.path.join(save_directory, ".saving_info")) @@ -210,6 +211,7 @@ def save_non_merge_optimizer(self, model, optim_state_dict, master_weights, outp static_name, type_name = generate_base_static_name(key) new_name = static2struct_name_mappings[static_name] + "/" + type_name optim_state_dict[new_name] = optim_state_dict.pop(key) + if master_weights is not None: for key in list(master_weights.keys()): master_weights[static2struct_name_mappings[key]] = master_weights.pop(key) @@ -237,6 +239,15 @@ def save_non_merge_optimizer(self, model, optim_state_dict, master_weights, outp optimizer_name = _add_variant(SAFE_OPTIMIZER_NAME, self.args.optimizer_name_suffix) master_weights_name = _add_variant(SAFE_MASTER_WEIGHTS_NAME, self.args.optimizer_name_suffix) + # save opt index json if checkpoint quantization is on. + if self.args.ckpt_quant_stage != "O0": + sharded_optim_index = {"ckpt_quant_stage": self.args.ckpt_quant_stage} + optimizer_index_name = SAFE_OPTIMIZER_INDEX_NAME + path = os.path.join(output_dir, optimizer_index_name) + if self.args.should_save: + with open(path, "w") as f: + json.dump(sharded_optim_index, f, indent=4) + is_sync_save = True if "async_save" in self.args.unified_checkpoint_config: is_sync_save = False @@ -246,16 +257,18 @@ def save_non_merge_optimizer(self, model, optim_state_dict, master_weights, outp signal_path=signal_dir, is_sync=is_sync_save, state_dict_type="optimizer_weight", + ckpt_quant_stage=self.args.ckpt_quant_stage, ) - self.async_handler._file_save_async_or_sync( - master_weights, - path=os.path.join(output_dir, master_weights_name), - signal_path=signal_dir, - is_sync=is_sync_save, - state_dict_type="master_weight", - ) + if master_weights is not None: + self.async_handler._file_save_async_or_sync( + master_weights, + path=os.path.join(output_dir, master_weights_name), + signal_path=signal_dir, + is_sync=is_sync_save, + state_dict_type="master_weight", + ) - def load_non_merge_optimizer(self, model, optimizer, resume_from_checkpoint): + def load_non_merge_optimizer(self, model, optimizer, resume_from_checkpoint, ckpt_quant_stage="O0"): # init and get optimizer LR_Scheduler returned_optim_state_dict = nested_copy(optimizer.state_dict()) @@ -263,19 +276,25 @@ def load_non_merge_optimizer(self, model, optimizer, resume_from_checkpoint): master_weights_name = _add_variant(SAFE_MASTER_WEIGHTS_NAME, self.args.optimizer_name_suffix) optimizer_path = os.path.join(resume_from_checkpoint, optimizer_name) master_weights_path = os.path.join(resume_from_checkpoint, master_weights_name) - has_master_weights = True if os.path.isfile(master_weights_path) else False + # no quantization & no master weight represent O1 AMP strategy. 
+ is_amp_o1 = True if not os.path.isfile(master_weights_path) and ckpt_quant_stage == "O0" else False model_state_dict = get_expected_state_dict(model) struct2static_name_mappings = {k: v.name for k, v in model_state_dict.items()} # get optimizer param mappings - optimizer_state_dict = load_state_dict(optimizer_path, None, None, device="expected") - if has_master_weights: + optimizer_state_dict = load_state_dict( + optimizer_path, None, None, device="expected", ckpt_quant_stage=ckpt_quant_stage + ) + master_weights = {} + # normal AMP O2 + if not is_amp_o1 and os.path.isfile(master_weights_path): master_weights = load_state_dict(master_weights_path, None, None, device="expected") # rename and move to paddle.Tensor for key in list(optimizer_state_dict.keys()): key_name = key.split("/") + model_weight_key = key_name[0] static_name = struct2static_name_mappings[key_name[0]] - if has_master_weights: + if not is_amp_o1: if model_state_dict[key_name[0]].dtype != paddle.float32: key_name = "_".join([static_name, FP32_MASTER, key_name[1]]) else: @@ -285,7 +304,13 @@ def load_non_merge_optimizer(self, model, optimizer, resume_from_checkpoint): returned_optim_state_dict[key_name] = optimizer_state_dict.pop(key) returned_optim_state_dict[key_name].name = key_name - if has_master_weights: + # master weight cast (only in AMP O2 + remove_master_weight) + if not is_amp_o1 and not os.path.isfile(master_weights_path): + master_weights[model_weight_key] = paddle.cast( + model_state_dict[model_weight_key], dtype=paddle.float32 + ) + + if not is_amp_o1: returned_optim_state_dict["master_weights"] = {} for key in list(master_weights.keys()): static_name = struct2static_name_mappings[key] @@ -320,6 +345,10 @@ def save_unified_optimizer(self, model, optimizer, output_dir, signal_dir): if "LR_Scheduler" in optim_state_dict.keys(): optim_state_dict.pop("LR_Scheduler") + if UnifiedCheckpointOption.REMOVE_MASTER_WEIGHT.value in self.args.unified_checkpoint_config: + logger.info("Skip master weight saving.") + master_weights = None + if "ignore_merge_optimizer" in self.args.unified_checkpoint_config: self.save_non_merge_optimizer(model, optim_state_dict, master_weights, output_dir, signal_dir) return @@ -350,6 +379,7 @@ def save_unified_optimizer(self, model, optimizer, output_dir, signal_dir): signal_path=signal_dir, is_sync=is_sync_save, state_dict_type="optimizer_weight", + ckpt_quant_stage=self.args.ckpt_quant_stage, ) if master_weight_state_dict is not None: self.async_handler._file_save_async_or_sync( @@ -391,16 +421,26 @@ def load_unified_optimizer(self, model, optimizer, resume_from_checkpoint): optim_state_dict = load_single_card_optimizer(model, optimizer, resume_from_checkpoint) return optim_state_dict + index = {} has_merge_optimizer_safetensors = distributed_isfile( os.path.join(resume_from_checkpoint, SAFE_OPTIMIZER_INDEX_NAME) ) + if has_merge_optimizer_safetensors: + with open(os.path.join(resume_from_checkpoint, SAFE_OPTIMIZER_INDEX_NAME), "r") as f: + index = json.loads(f.read()) + + ckpt_quant_stage = "O0" + if "ckpt_quant_stage" in index: + ckpt_quant_stage = index["ckpt_quant_stage"] + # If not having merge optimizer, then load non-merge optimizer. 
- if not has_merge_optimizer_safetensors: + if "weight_map" not in index: if self.args.data_parallel_rank == 0 or self.args.use_expert_parallel: returned_optim_state_dict = self.load_non_merge_optimizer( model, optimizer, resume_from_checkpoint, + ckpt_quant_stage=ckpt_quant_stage, ) return returned_optim_state_dict else: @@ -445,7 +485,7 @@ def unified_checkpoint_into_shards( assert hasattr(model_to_save, "config") state_dict = get_expected_state_dict(model_to_save) - all_filter_keys = filter_params(model_to_save, state_dict) + all_filter_keys = filter_params(model_to_save, state_dict, args) config_to_save = copy.deepcopy(model_to_save.config) @@ -534,6 +574,7 @@ def unified_optimizer_into_shards( static_name, type_name = generate_base_static_name(key) new_name = static2struct_name_mappings[static_name] + "/" + type_name optim_state_dict[new_name] = optim_state_dict.pop(key) + if master_weights is not None: for key in list(master_weights.keys()): master_weights[static2struct_name_mappings[key]] = master_weights.pop(key) @@ -541,8 +582,8 @@ def unified_optimizer_into_shards( # filter optimizer param if master_weights is not None: - filter_master_keys = filter_params(model, master_weights, is_optimizer=True) - filter_optim_keys = filter_params(model, optim_state_dict, is_optimizer=True) + filter_master_keys = filter_params(model, master_weights, args, is_optimizer=True) + filter_optim_keys = filter_params(model, optim_state_dict, args, is_optimizer=True) tp_group = fleet.get_hybrid_communicate_group().get_model_parallel_group() tp_size = tp_group.nranks @@ -605,6 +646,10 @@ def unified_optimizer_into_shards( use_expert_parallel=args.use_expert_parallel, ) sharded_optim_index = get_sharded_index(index_optimizer_filelist, total_optim_size_list) + + if args.should_save and args.ckpt_quant_stage in ["O1", "O2"]: + sharded_optim_index["ckpt_quant_stage"] = args.ckpt_quant_stage + if master_weights is not None: index_master_weight_filelist, total_master_weight_size_list = gather_sharded_object( index_master_weight_file, diff --git a/paddlenlp/trainer/unified_checkpoint/utils.py b/paddlenlp/trainer/unified_checkpoint/utils.py index 9bd9fdcc65b7..58e425ca987d 100644 --- a/paddlenlp/trainer/unified_checkpoint/utils.py +++ b/paddlenlp/trainer/unified_checkpoint/utils.py @@ -32,6 +32,10 @@ from paddlenlp.transformers.utils import dtype_byte_size from paddlenlp.utils.distributed import distributed_allgather, distributed_gather from paddlenlp.utils.env import ( + BETA1_KEYNAME, + BETA2_KEYNAME, + MOMENT1_KEYNAME, + MOMENT2_KEYNAME, PADDLE_MASTER_WEIGHTS_INDEX_NAME, PADDLE_PEFT_WEIGHTS_INDEX_NAME, PADDLE_WEIGHTS_INDEX_NAME, @@ -72,6 +76,7 @@ class UnifiedCheckpointOption(ExplicitEnum): SKIP_SAVE_MODEL_WEIGHT = "skip_save_model_weight" MASTER_WEIGHT_COMPATIBLE = "master_weight_compatible" + REMOVE_MASTER_WEIGHT = "remove_master_weight" ASYNC_SAVE = "async_save" IGNORE_MERGE_OPTIMIZER = "ignore_merge_optimizer" @@ -96,7 +101,10 @@ def is_need_master_weight(optimizer, is_fp16_or_bp16): def update_master_weight_status(args, optimizer, has_master_weight, safe_serialization): if is_need_master_weight(optimizer, is_fp16_or_bp16=(args.fp16 or args.bf16)): if not has_master_weight: - if UnifiedCheckpointOption.MASTER_WEIGHT_COMPATIBLE.value in args.unified_checkpoint_config: + if ( + UnifiedCheckpointOption.REMOVE_MASTER_WEIGHT.value in args.unified_checkpoint_config + or UnifiedCheckpointOption.MASTER_WEIGHT_COMPATIBLE.value in args.unified_checkpoint_config + ): index_filename_master_weights = ( 
PADDLE_WEIGHTS_INDEX_NAME if not safe_serialization else SAFE_WEIGHTS_INDEX_NAME ) @@ -108,7 +116,8 @@ def update_master_weight_status(args, optimizer, has_master_weight, safe_seriali else: raise ValueError( "Can't find a valid unified master weight checkpoint," - f"add '{UnifiedCheckpointOption.MASTER_WEIGHT_COMPATIBLE.value}' into 'unified_checkpoint_config' to " + f"add '{UnifiedCheckpointOption.MASTER_WEIGHT_COMPATIBLE.value}'" + f" or '{UnifiedCheckpointOption.REMOVE_MASTER_WEIGHT.value}' into 'unified_checkpoint_config' to " "load model checkpoint as master weight" ) else: @@ -463,7 +472,7 @@ def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys, return state_dict_to_save -def filter_params(model_to_save, state_dict, is_optimizer=False): +def filter_params(model_to_save, state_dict, args, is_optimizer=False): """ Group according to the size of the tensor, aiming to make the weight size stored on each device as equal as possible. @@ -479,16 +488,34 @@ def filter_params(model_to_save, state_dict, is_optimizer=False): return [list(state_dict.keys())] filter_tensor_list = [[] for _ in range(tp_size)] + is_master_weights = False if tp_rank == 0: + quant = False + if args.ckpt_quant_stage != "O0": + quant = True tensor_bytes_dict = {} model_state_dict = get_expected_state_dict(model_to_save) for (k, v) in state_dict.items(): - model_v = model_state_dict[k.split("/")[0]] if is_optimizer else v - if hasattr(model_v, "is_distributed") and model_v.is_distributed: - tensor_bytes_dict[k] = v.numel().item() * tp_size * dtype_byte_size(v.dtype) + # master weight has same key as model weight + if not is_master_weights and k in model_state_dict: + is_master_weights = True + + weight_key = k.split("/")[0] + model_v = model_state_dict[weight_key] if is_optimizer else v + if not quant or not is_optimizer: + if hasattr(model_v, "is_distributed") and model_v.is_distributed: + tensor_bytes_dict[k] = v.numel().item() * tp_size * dtype_byte_size(v.dtype) + else: + tensor_bytes_dict[k] = v.numel().item() * dtype_byte_size(v.dtype) else: - tensor_bytes_dict[k] = v.numel().item() * dtype_byte_size(v.dtype) + if weight_key not in tensor_bytes_dict: + tensor_bytes_dict[weight_key] = 0 + + if hasattr(model_v, "is_distributed") and model_v.is_distributed: + tensor_bytes_dict[weight_key] += v.numel().item() * tp_size * dtype_byte_size(v.dtype) + else: + tensor_bytes_dict[weight_key] += v.numel().item() * dtype_byte_size(v.dtype) filter_tensor_list = [] current_block = [] @@ -509,7 +536,14 @@ def filter_params(model_to_save, state_dict, is_optimizer=False): current_block = [] current_block_size = 0 - current_block.append(key) + if not quant or not is_optimizer or is_master_weights: + current_block.append(key) + else: + current_block.append(key + "/" + MOMENT1_KEYNAME) + current_block.append(key + "/" + MOMENT2_KEYNAME) + current_block.append(key + "/" + BETA1_KEYNAME) + current_block.append(key + "/" + BETA2_KEYNAME) + current_block_size += weight_size total_size += weight_size diff --git a/paddlenlp/transformers/model_utils.py b/paddlenlp/transformers/model_utils.py index c4e7d2786307..1cc21b7d9f1a 100644 --- a/paddlenlp/transformers/model_utils.py +++ b/paddlenlp/transformers/model_utils.py @@ -54,6 +54,8 @@ from tqdm.auto import tqdm from paddlenlp.utils.env import ( + ASYMMETRY_QUANT_SCALE_MAX, + ASYMMETRY_QUANT_SCALE_MIN, CONFIG_NAME, LEGACY_CONFIG_NAME, PADDLE_WEIGHTS_INDEX_NAME, @@ -64,10 +66,12 @@ SAFE_PEFT_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, + 
SYMMETRY_QUANT_SCALE, ) from paddlenlp.utils.log import logger from ..generation import GenerationConfig, GenerationMixin +from ..quantization.unified_checkpoint_quantization import dequant_unified_optimizer from ..utils import device_guard from ..utils.download import resolve_file_path from .configuration_utils import PretrainedConfig @@ -362,10 +366,21 @@ def _load_part_state_dict( """ part_state_dict = {} + scale_dict = {} with safe_open(checkpoint_file, framework="np") as f: for key in keys: + # 1. non-merge ckpt loading dont have filter key. + # 2. merge ckpt will skip quant scale by `fliter_dict_keys` + if ( + key.endswith(SYMMETRY_QUANT_SCALE) + or key.endswith(ASYMMETRY_QUANT_SCALE_MIN) + or key.endswith(ASYMMETRY_QUANT_SCALE_MAX) + ): + continue + if fliter_dict_keys is not None and key not in fliter_dict_keys: continue + py_safe_slice_ = f.get_slice(key) if key in tensor_parallel_split_mapping: weight = tensor_parallel_split_mapping[key](py_safe_slice_) @@ -376,15 +391,31 @@ def _load_part_state_dict( weight = paddle.Tensor(weight, zero_copy=True) weight = weight._copy_to(paddle.framework._current_expected_place(), False) part_state_dict[key] = weight - return part_state_dict + for key in keys: + if ( + key.endswith(SYMMETRY_QUANT_SCALE) + or key.endswith(ASYMMETRY_QUANT_SCALE_MIN) + or key.endswith(ASYMMETRY_QUANT_SCALE_MAX) + ): + scale = f.get_tensor(key) + with device_guard(): + scale = paddle.Tensor(scale, zero_copy=True) + scale = scale._copy_to(paddle.framework._current_expected_place(), False) + scale_dict[key] = scale + return part_state_dict, scale_dict def load_state_dict( - checkpoint_file: Union[str, os.PathLike], tensor_parallel_split_mapping=None, fliter_dict_keys=None, device="cpu" + checkpoint_file: Union[str, os.PathLike], + tensor_parallel_split_mapping=None, + fliter_dict_keys=None, + device="cpu", + ckpt_quant_stage="O0", ): """ Reads a PaddlePaddle checkpoint file, returning properly formatted errors if they arise. 
""" + if tensor_parallel_split_mapping is None: tensor_parallel_split_mapping = {} @@ -404,10 +435,9 @@ def load_state_dict( raise ValueError("Currently unsupport paddle weights file, use numpy instead.") if metadata.get("format", "np") == "np": thread_num = int(os.environ.get("LOAD_STATE_DICT_THREAD_NUM", "1")) - state_dict = {} if thread_num <= 1: with safe_open(checkpoint_file, framework="np") as f: - state_dict = _load_part_state_dict( + state_dict, scale_dict = _load_part_state_dict( list(f.keys()), checkpoint_file, tensor_parallel_split_mapping, @@ -431,14 +461,20 @@ def load_state_dict( for keys in keys_groups } for future in concurrent.futures.as_completed(future_to_key): - result = future.result() - state_dict.update(result) + state_dict, scale_dict = future.result() + state_dict.update(state_dict) + scale_dict.update(scale_dict) if device == "cpu": for k in list(state_dict.keys()): with device_guard(): state_dict[k] = paddle.Tensor(state_dict.pop(k), zero_copy=True) + if len(scale_dict) != 0: + if ckpt_quant_stage == "O0": + raise ValueError('optimizer weight has quantization scales but `ckpt_quant_stage` is set to "O0"') + state_dict = dequant_unified_optimizer(state_dict, ckpt_quant_stage, scale_dict) + return state_dict state_dict = paddlenlp_load(checkpoint_file, map_location="cpu") @@ -2102,7 +2138,9 @@ def _fuse_or_split_keys( if config.quantization_config.is_weight_quantize(): filter_dict_keys = None state_dict = load_state_dict( - shard_file, tp_actions if pre_tensor_parallel_split else None, filter_dict_keys + shard_file, + tp_actions if pre_tensor_parallel_split else None, + filter_dict_keys, ) # convert for fusing or splitting weights diff --git a/paddlenlp/utils/env.py b/paddlenlp/utils/env.py index d1fbbb1a60ba..7ece736c537b 100644 --- a/paddlenlp/utils/env.py +++ b/paddlenlp/utils/env.py @@ -111,3 +111,12 @@ def _get_bool_env(env_key: str, default_value: str) -> bool: SAFE_PEFT_WEIGHTS_NAME = "peft_model.safetensors" SAFE_PEFT_WEIGHTS_INDEX_NAME = "peft_model.safetensors.index.json" + +# Checkpoint quantization +MOMENT1_KEYNAME = "moment1_0" +MOMENT2_KEYNAME = "moment2_0" +BETA1_KEYNAME = "beta1_pow_acc_0" +BETA2_KEYNAME = "beta2_pow_acc_0" +SYMMETRY_QUANT_SCALE = "@scales" +ASYMMETRY_QUANT_SCALE_MIN = "@min_scales" +ASYMMETRY_QUANT_SCALE_MAX = "@max_scales" diff --git a/tests/fixtures/llm/finetune.yaml b/tests/fixtures/llm/finetune.yaml index 7e79f9b441a8..abe9aad5d39e 100644 --- a/tests/fixtures/llm/finetune.yaml +++ b/tests/fixtures/llm/finetune.yaml @@ -63,4 +63,53 @@ inference-infer: dtype: float16 batch_size: 2 decode_strategy: greedy_search - max_length: 20 \ No newline at end of file + max_length: 20 + +ckpt_quant: + base: + dataset_name_or_path: "./data" + per_device_train_batch_size: 4 + gradient_accumulation_steps: 4 + per_device_eval_batch_size: 8 + eval_accumulation_steps: 16 + num_train_epochs: 3 + learning_rate: 3e-05 + warmup_steps: 30 + logging_steps: 1 + evaluation_strategy: "steps" + save_strategy: "steps" + save_steps: 1 + max_steps: 1 + src_length: 1024 + max_length: 2048 + fp16: true + fp16_opt_level: "O2" + ckpt_quant_stage: "O1" + do_train: true + do_eval: true + use_flash_attention: true + unified_checkpoint: true + unified_checkpoint_config: "async_save remove_master_weight" + disable_tqdm: true + load_best_model_at_end: true + eval_with_do_generation: false + metric_for_best_model: "accuracy" + recompute: true + tensor_parallel_degree: 2 + pipeline_parallel_degree: 1 + + default: + llama: + model_name_or_path: 
__internal_testing__/tiny-random-llama + chatglm: + model_name_or_path: __internal_testing__/tiny-fused-chatglm + chatglm2: + model_name_or_path: __internal_testing__/tiny-fused-chatglm2 + bloom: + model_name_or_path: __internal_testing__/tiny-fused-bloom + qwen: + model_name_or_path: __internal_testing__/tiny-fused-qwen + baichuan: + model_name_or_path: __internal_testing__/tiny-fused-baichuan + qwen2: + model_name_or_path: __internal_testing__/tiny-random-qwen2 diff --git a/tests/llm/test_finetune.py b/tests/llm/test_finetune.py index d1fda6e67b5f..76c88a3c6a94 100644 --- a/tests/llm/test_finetune.py +++ b/tests/llm/test_finetune.py @@ -18,6 +18,7 @@ from parameterized import parameterized_class +from tests.parallel_launch import TestMultipleGpus from tests.testing_utils import argv_context_guard, load_test_config from .testing_utils import LLMTest @@ -63,3 +64,38 @@ def test_finetune(self): self.run_predictor({"inference_model": True}) self.run_predictor({"inference_model": False}) + + +@parameterized_class( + ["model_dir"], + [ + ["llama"], + ], +) +class CkptQuantTest(LLMTest, TestMultipleGpus): + config_path: str = "./tests/fixtures/llm/finetune.yaml" + model_dir: str = None + + def setUp(self) -> None: + LLMTest.setUp(self) + + sys.path.insert(0, self.model_dir) + self.run_sft = "llm/run_finetune.py" + + def tearDown(self) -> None: + LLMTest.tearDown(self) + + def test_ckpt_quant(self): + finetune_config = load_test_config(self.config_path, "ckpt_quant", self.model_dir) + + finetune_config["dataset_name_or_path"] = self.data_dir + finetune_config["output_dir"] = self.output_dir + + self.runfirst(finetune_config) + self.rerun(finetune_config) + + def runfirst(self, train_args): + self.run_n1c2(self.run_sft, **train_args) + + def rerun(self, train_args): + self.run_n1c2(self.run_sft, **train_args)
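
For reviewers who want to sanity-check the new helpers, here is a minimal round-trip sketch of the channel-wise symmetric quantizer introduced in checkpoint_quantization_utils.py (illustrative only, not part of the patch; the array shape and values are made up, the import path is the one added by this diff):

# Minimal sketch: quantize a stand-in optimizer moment to int8 and dequantize
# it back with the helper added in this PR (numpy path, use_pd=False).
import numpy as np

from paddlenlp.quantization.checkpoint_quantization_utils import qdq_weight

w = np.random.randn(64, 128).astype("float32")  # stand-in for an Adam moment tensor

# Quantize: per-channel abs-max scales are computed internally when scales=None.
w_int8, scales = qdq_weight(w, quant_bit=8)

# Dequantize with the returned scales.
w_restored, _ = qdq_weight(w_int8, quant_bit=8, scales=scales, dequant=True)

print("max abs reconstruction error:", np.abs(w - w_restored).max())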