Skip to content

Commit

Permalink
mergekit_with_sparsify (#9561)
Browse files Browse the repository at this point in the history
* mergekit_with_sparsify

* mergekit_with_sparsify

* 参数上传修改

* change sparsify style

* test

* add merge

* add merge

* add merge

* add merge

* add test

* add tensor

* add merge

* add del

* add merge

* add merge

* del base

* del base

* alpha

* update follow comments

---------

Co-authored-by: lugimzzz <zhenglujing@baidu.com>
  • Loading branch information
Mangodadada and lugimzzz authored Dec 18, 2024
1 parent 407b3e6 commit fa0febc
Show file tree
Hide file tree
Showing 17 changed files with 1,259 additions and 0 deletions.
35 changes: 35 additions & 0 deletions llm/tools/merge_weight.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys

from paddlenlp.mergekit import MergeConfig, MergeModel
from paddlenlp.trainer import PdArgumentParser
from paddlenlp.utils.log import logger


def merge():
parser = PdArgumentParser((MergeConfig))
if len(sys.argv) >= 2 and sys.argv[1].endswith(".json"):
merge_config = parser.parse_json_file_and_cmd_lines()[0]
else:
merge_config = parser.parse_args_into_dataclasses()[0]

mergekit = MergeModel(merge_config)
logger.info("Start to merge model.")
mergekit.merge_model()
logger.info("Finish merging model.")


if __name__ == "__main__":
merge()
1 change: 1 addition & 0 deletions paddlenlp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
experimental,
layers,
losses,
mergekit,
metrics,
ops,
peft,
Expand Down
19 changes: 19 additions & 0 deletions paddlenlp/mergekit/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you smay not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .merge_config import *
from .merge_method import *
from .merge_model import *
from .merge_utils import *
from .sparsify_method import *
168 changes: 168 additions & 0 deletions paddlenlp/mergekit/merge_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
from dataclasses import asdict, dataclass, field
from typing import List, Optional

import paddle

from paddlenlp.utils.env import MERGE_CONFIG_NAME
from paddlenlp.utils.log import logger


@dataclass
class MergeConfig:
"""
This is the configuration class to store the configuration of a [`MergeKit`].
"""

# Common parameters
device: str = field(default="cpu", metadata={"help": "Device to use for the merge.ex cpu、 gpu、low_gpu_mem"})
tensor_type: str = field(
default="np", metadata={"help": "Tensor type to use for the merge. Choose np(CPU Only) or pd (CPU/GPU)"}
)
n_process: int = field(default=1, metadata={"help": "Number of processes to use for the merge."})
merge_preifx: str = field(default="model", metadata={"help": "Prefix name: model or master_weights"})
merge_method: str = field(default="linear", metadata={"help": "The merge strategy."})
merge_type: str = field(default="linear", metadata={"help": "The type of merge process."})
sparsify_type: str = field(default=None, metadata={"help": "The type of sparsify process."})

# Model parameters
model_path_list: Optional[List[str]] = field(default=None, metadata={"help": "Merge model name or path list"})
model_path_str: Optional[str] = field(
default=None, metadata={"help": "Merge model name or path string.(split by ',')"}
)
base_model_path: str = field(default=None, metadata={"help": "Base model name or path."})
output_path: str = field(default=None, metadata={"help": "Base model name or path."})
# merge parameters
weight_list: Optional[List[float]] = field(
default=None, metadata={"help": "Relative (or absolute if normalize=False) weighting of a given tensor"}
)
normalize: bool = field(default=True, metadata={"help": "Whether to normalize the weighting."})
slerp_alpha: float = field(default=0.5, metadata={"help": "Slerp alpha."})
slerp_normalize_eps: float = field(default=1e-8, metadata={"help": "Slerp normalization epsilon value"})
slerp_dot_threshold: float = field(
default=0.9995,
metadata={
"help": "Slerp dot threshold. If dot value exceeds this threshold, then we consider them as colinear, so use linear instead."
},
)
ties_elect_type: str = field(default="sum", metadata={"help": "The type of ties mask. 'sum' or 'count'"})

# Sparsify parameters
rescale: bool = field(default=True, metadata={"help": "Rescale the weights after sparsifying."})
reserve_p: float = field(default=0.7, metadata={"help": "Random reserve probability for the sparsify model."})
epsilon: float = field(default=0.14, metadata={"help": "Epsilon value for magprune."})

def __post_init__(self):
self.config_check()

def config_check(self):
if self.output_path is not None:
os.makedirs(self.output_path, exist_ok=True)
if self.tensor_type not in ["np"]:
raise ValueError(f"Unsupported tensor type: {self.tensor_type}. Support 'np' only.")
if self.device != "cpu":
logger.warning(f"Currently only support cpu device, but got {self.device}. Setting `device` to `cpu`.")
self.device = "cpu"
self.tensor_type = "np"

elif self.merge_method not in ["linear", "ties", "slerp", "della_linear", "della", "dare_linear", "dare_ties"]:
raise ValueError(
f"Unsupported merge strategy: {self.merge_method}. Please choose one from ['linear', 'slerp']."
)
if self.model_path_str is not None:
self.model_path_list = self.model_path_str.split(",")
if self.model_path_list is not None:
if not isinstance(self.model_path_list, list) or len(self.model_path_list) < 2:
raise ValueError(f"Please specify the model_path_list at least two. But got {self.model_path_list}")
if self.weight_list is None:
self.weight_list = [1.0] * len(self.model_path_list)
self.normalize = True
if len(self.model_path_list) != len(self.weight_list):
raise ValueError("The length of model_path_list and weight_list must be the same.")
if self.reserve_p < 0 or self.reserve_p > 1:
raise ValueError("reserve_p must be between 0 and 1.")
if "della" in self.merge_method or self.sparsify_type == "magprune":
if self.reserve_p <= self.epsilon / 2 or self.reserve_p >= (1 - self.epsilon):
raise ValueError(
f"Error: reserve_p +- epsilon/2 must be in the range (0, 1). reserve_p + epsilon/2 = {self.reserve_p + self.epsilon / 2 }, reserve_p - epsilon/2 = {self.reserve_p - self.epsilon / 2 }"
)
paddle.set_device(self.device)

@property
def __dict__(self):
return asdict(self)

def to_dict(self):
return self.__dict__

def save_pretrained(self, save_directory):
r"""
This method saves the configuration of your adapter model in a directory.
Args:
save_directory (`str`):
The directory where the configuration will be saved.
"""
if os.path.isfile(save_directory):
raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")

os.makedirs(save_directory, exist_ok=True)

output_dict = self.__dict__
output_path = os.path.join(save_directory, MERGE_CONFIG_NAME)

# save it
with open(output_path, "w") as writer:
writer.write(json.dumps(output_dict, indent=2, sort_keys=True))

@classmethod
def from_pretrained(cls, pretrained_model_path, **kwargs):
r"""
This method loads the configuration of your adapter model from a directory.
Args:
pretrained_model_path (`str`):
The directory or the hub-id where the configuration is saved.
**kwargs:
Additional keyword arguments passed along to the child class initialization.
"""
if os.path.isfile(os.path.join(pretrained_model_path, MERGE_CONFIG_NAME)):
config_file = os.path.join(pretrained_model_path, MERGE_CONFIG_NAME)
else:
raise ValueError(f"Can't find merge_config.json at '{pretrained_model_path}'")

loaded_attributes = cls.from_json_file(config_file)

config = cls(**kwargs)

for key, value in loaded_attributes.items():
if hasattr(config, key):
setattr(config, key, value)

return config

@classmethod
def from_json_file(cls, path_json_file):
r"""
Loads a configuration file from a json file.
Args:
path_json_file (`str`):
The path to the json file.
"""
with open(path_json_file, "r") as file:
json_object = json.load(file)

return json_object
130 changes: 130 additions & 0 deletions paddlenlp/mergekit/merge_method.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np


class MergeMethod:
def __init__(self, merge_config, sparsify_method=None):
self.merge_config = merge_config
self.sparsify_method = sparsify_method

def merge(self, tensor_list):
if self.sparsify_method is not None:
tensor_list = [self.sparsify_method.sparsify(tensor) for tensor in tensor_list]
if self.merge_config.merge_type == "linear":
return self.linear(tensor_list)
elif self.merge_config.merge_type == "slerp":
return self.slerp(tensor_list)
elif self.merge_config.merge_type == "ties":
return self.ties(tensor_list)
else:
raise NotImplementedError(f"{self.merge_config.merge_type} is not supported yet.")

def linear(self, tensor_list):
"""
Linear interpolation between multiple values.
"""
# init weight
weight_list = self.merge_config.weight_list
if self.merge_config.normalize:
weight_sum = sum(weight_list)
weight_list = [weight / weight_sum for weight in weight_list]

# merge
if self.merge_config.tensor_type == "np":
tensor_output = sum(weight * tensor for weight, tensor in zip(weight_list, tensor_list))
return tensor_output
else:
raise NotImplementedError("Paddle Tensor is not supported yet.")

def slerp(self, tensor_list):
"""
Spherical linear interpolation
"""
# check tensor_list length
if len(tensor_list) != 2:
raise ValueError("Slerp only support two tensors merge.")

if self.merge_config.tensor_type == "np":
t0, t1 = tensor_list
# Copy the vectors to reuse them later
t0_copy = np.copy(t0)
t1_copy = np.copy(t1)

# Normalize the vectors to get the directions and angles
t0 = self.normalize(t0)
t1 = self.normalize(t1)

# Dot product with the normalized vectors (can't use np.dot in W)
dot = np.sum(t0 * t1)
# If absolute value of dot product is almost 1, vectors are ~colinear, so use lerp
if np.abs(dot) > self.merge_config.slerp_dot_threshold:
return (1 - self.merge_config.slerp_alpha) * t0_copy + self.merge_config.slerp_alpha * t1_copy

# Calculate initial angle between t0 and t1
theta_0 = np.arccos(dot)
sin_theta_0 = np.sin(theta_0)

# Angle at timestep t
theta_t = theta_0 * self.merge_config.slerp_alpha
sin_theta_t = np.sin(theta_t)

# Finish the slerp algorithm
s0 = np.sin(theta_0 - theta_t) / sin_theta_0
s1 = sin_theta_t / sin_theta_0

return s0 * t0_copy + s1 * t1_copy
else:
raise NotImplementedError("Paddle Tensor is not supported yet.")

def ties(self, tensor_list):
if self.merge_config.tensor_type == "np":
# Get weight tensor
mask_dtype = tensor_list[0].dtype
weight_list = self.merge_config.weight_list
tensor_list = [weight * tensor for (weight, tensor) in zip(weight_list, tensor_list)]

# Elect majority sign
sign_tensor_list = [np.sign(tensor).astype(mask_dtype) for tensor in tensor_list]
if self.merge_config.ties_elect_type == "sum":
majority_sign = (np.sum(tensor_list, axis=0) >= 0).astype(mask_dtype) * 2 - 1
elif self.merge_config.ties_elect_type == "count":
majority_sign = (np.sum(sign_tensor_list, axis=0) >= 0).astype(mask_dtype) * 2 - 1
else:
raise NotImplementedError(f"ties_elect_type: {self.merge_config.ties_elect_type} is unknown.")

# Merge
mask_list = [sign_tensor == majority_sign for sign_tensor in sign_tensor_list]
tensor_list = [mask * tensor for mask, tensor in zip(mask_list, tensor_list)]
merge_tensor = np.sum(tensor_list, axis=0)

# Normalize
if self.merge_config.normalize:
weight_mask = [mask * weight for mask, weight in zip(mask_list, weight_list)]
divisor = np.sum(weight_mask, axis=0)
divisor[np.abs(divisor) < 1e-8] = 1
merge_tensor /= divisor
return merge_tensor
else:
raise NotImplementedError("Paddle Tensor is not supported yet.")

def normalize(self, t):
"""
Normalize a vector by its L2 norm.
"""
norm_t = np.linalg.norm(t)
if norm_t > self.merge_config.slerp_normalize_eps:
t = t / norm_t
return t
Loading

0 comments on commit fa0febc

Please sign in to comment.