add magnitude_prune merging method (#1466)
* add `magnitude_prune` merging method

* Update model.py

* 😅
pacman100 authored Feb 15, 2024
1 parent 83de1af commit 25dec60
Showing 3 changed files with 70 additions and 8 deletions.
27 changes: 19 additions & 8 deletions src/peft/tuners/lora/model.py
@@ -36,7 +36,7 @@
_get_submodules,
get_quantization_config,
)
from peft.utils.merge_utils import dare_linear, dare_ties, task_arithmetic, ties
from peft.utils.merge_utils import dare_linear, dare_ties, magnitude_prune, task_arithmetic, ties

from .config import LoraConfig
from .gptq import dispatch_gptq
@@ -395,9 +395,9 @@ def add_weighted_adapter(
Name of the new adapter.
combination_type (`str`):
The merging type can be one of [`svd`, `linear`, `cat`, `ties`, `ties_svd`, `dare_ties`, `dare_linear`,
`dare_ties_svd`, `dare_linear_svd`]. When using the `cat` combination_type, the rank of the resulting
adapter is equal to the sum of all adapters ranks (the mixed adapter may be too big and result in OOM
errors).
`dare_ties_svd`, `dare_linear_svd`, `magnitude_prune`, `magnitude_prune_svd`]. When using the `cat`
combination_type, the rank of the resulting adapter is equal to the sum of all adapters ranks (the
mixed adapter may be too big and result in OOM errors).
svd_rank (`int`, *optional*):
Rank of output adapter for svd. If None provided, will use max rank of merging adapters.
svd_clamp (`float`, *optional*):
@@ -412,7 +412,8 @@ def add_weighted_adapter(
documentation. Defaults to None.
density (`float`, *optional*):
Value between 0 and 1. 0 means all values are pruned and 1 means no values are pruned. Should be used
with [`ties`, `ties_svd`, `dare_ties`, `dare_linear`, `dare_ties_svd`, `dare_linear_svd`]
with [`ties`, `ties_svd`, `dare_ties`, `dare_linear`, `dare_ties_svd`, `dare_linear_svd`,
`magnitude_prune`, `magnitude_prune_svd`]
majority_sign_method (`str`):
The method used to obtain the magnitude of the sign values; should be one of ["total", "frequency"].
Should be used with [`ties`, `ties_svd`, `dare_ties`, `dare_ties_svd`]
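
For reference, a minimal usage sketch of the new combination types described in the docstring above. The adapter names, weights, and the assumption that `model` is a PeftModel holding two equal-rank LoRA adapters are illustrative, not part of this commit; only `add_weighted_adapter` and its arguments come from the docstring.

```python
# Illustrative only: assumes `model` is a PeftModel that already holds two LoRA
# adapters of equal rank, named "adapter_a" and "adapter_b" (hypothetical names).
model.add_weighted_adapter(
    adapters=["adapter_a", "adapter_b"],
    weights=[0.7, 0.3],
    adapter_name="merged_magnitude_prune",
    combination_type="magnitude_prune",  # non-SVD variant: all adapter ranks must match
    density=0.5,  # keep the top 50% of values by magnitude in each adapter
)

# The SVD variant also works when the adapters have different ranks.
model.add_weighted_adapter(
    adapters=["adapter_a", "adapter_b"],
    weights=[0.7, 0.3],
    adapter_name="merged_magnitude_prune_svd",
    combination_type="magnitude_prune_svd",
    density=0.5,
)
model.set_adapter("merged_magnitude_prune")
```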
@@ -428,7 +429,7 @@ def add_weighted_adapter(
combination_type = "linear" if len(adapters) == 1 else combination_type

adapters_ranks = [self.peft_config[adapter].r for adapter in adapters]
if combination_type in ("linear", "ties", "dare_ties", "dare_linear"):
if combination_type in ("linear", "ties", "dare_ties", "dare_linear", "magnitude_prune"):
# all adapter ranks should be the same; the new rank is just this value
if len(set(adapters_ranks)) != 1:
raise ValueError(
@@ -509,7 +510,13 @@ def add_weighted_adapter(
loras_B = torch.cat(loras_B, dim=1)
target_lora_A.data[: loras_A.shape[0], :] = loras_A
target_lora_B.data[:, : loras_B.shape[1]] = loras_B
elif combination_type in ["svd", "ties_svd", "dare_linear_svd", "dare_ties_svd"]:
elif combination_type in [
"svd",
"ties_svd",
"dare_linear_svd",
"dare_ties_svd",
"magnitude_prune_svd",
]:
target_lora_A.data, target_lora_B.data = self._svd_generalized_task_arithmetic_weighted_adapter(
combination_type,
adapters,
@@ -524,7 +531,7 @@ def add_weighted_adapter(
full_matrices=svd_full_matrices,
driver=svd_driver,
)
elif combination_type in ["linear", "ties", "dare_linear", "dare_ties"]:
elif combination_type in ["linear", "ties", "dare_linear", "dare_ties", "magnitude_prune"]:
target_lora_A.data, target_lora_B.data = self._generalized_task_arithmetic_weighted_adapter(
combination_type, adapters, weights, target, density, majority_sign_method
)
@@ -565,6 +572,8 @@ def _svd_generalized_task_arithmetic_weighted_adapter(
delta_weight = dare_linear(delta_weight, valid_weights, density)
elif combination_type == "dare_ties_svd":
delta_weight = dare_ties(delta_weight, valid_weights, density, majority_sign_method)
elif combination_type == "magnitude_prune_svd":
delta_weight = magnitude_prune(delta_weight, valid_weights, density)
else:
raise ValueError(f"Invalid value passed to combination type: {combination_type}")

@@ -632,6 +641,8 @@ def _generalized_task_arithmetic_weighted_adapter(
lora_deltas[i] = dare_linear(task_tensors, valid_weights, density)
elif combination_type == "dare_ties":
lora_deltas[i] = dare_ties(task_tensors, valid_weights, density, majority_sign_method)
elif combination_type == "magnitude_prune":
lora_deltas[i] = magnitude_prune(task_tensors, valid_weights, density)
else:
raise ValueError("Invalid combination type")
lora_deltas = [delta.to(dtype) for delta in lora_deltas]
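In the non-SVD branch above, the chosen merge function is applied separately to the stacked lora_A matrices and to the stacked lora_B matrices, which is why all adapter ranks must match. A small sketch with made-up shapes and weights (the actual method also folds in each adapter's scaling and distributes the user-supplied weight across the two factors; that detail is omitted here):

```python
import torch
from peft.utils.merge_utils import magnitude_prune

# Two adapters of rank 8 on a 768-dimensional layer (shapes are illustrative).
lora_A_per_adapter = [torch.randn(8, 768), torch.randn(8, 768)]
lora_B_per_adapter = [torch.randn(768, 8), torch.randn(768, 8)]
weights = torch.tensor([0.7, 0.3])

merged_A = magnitude_prune(lora_A_per_adapter, weights, density=0.5)
merged_B = magnitude_prune(lora_B_per_adapter, weights, density=0.5)
# merged_A and merged_B would then be copied into the new adapter's lora_A / lora_B weights.
```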
22 changes: 22 additions & 0 deletions src/peft/utils/merge_utils.py
@@ -160,6 +160,28 @@ def task_arithmetic(task_tensors: List[torch.Tensor], weights: torch.Tensor) ->
return mixed_task_tensors


def magnitude_prune(task_tensors: List[torch.Tensor], weights: torch.Tensor, density: float) -> torch.Tensor:
"""
Prune the task tensors by magnitude, keeping the top `density` fraction of values, and merge them with a weighted sum.
Args:
task_tensors (`List[torch.Tensor]`): The task tensors to merge.
weights (`torch.Tensor`): The weights of the task tensors.
density (`float`): The fraction of values to preserve. Should be in [0,1].
Returns:
`torch.Tensor`: The merged tensor.
"""
# sparsify
task_tensors = [prune(tensor, density, method="magnitude") for tensor in task_tensors]
task_tensors = torch.stack(task_tensors, dim=0)
# weighted task tensors
weights = reshape_weight_task_tensors(task_tensors, weights)
weighted_task_tensors = task_tensors * weights
mixed_task_tensors = weighted_task_tensors.sum(dim=0)
return mixed_task_tensors


def ties(
task_tensors: List[torch.Tensor],
weights: torch.Tensor,
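The `prune(tensor, density, method="magnitude")` helper called by the new function is not part of this diff. Below is a minimal sketch of the operation it is expected to perform — keep the largest-magnitude `density` fraction of entries and zero the rest. This is an approximation for illustration, not the library implementation.

```python
import torch


def magnitude_prune_sketch(tensor: torch.Tensor, density: float) -> torch.Tensor:
    """Keep the top `density` fraction of entries by absolute value; zero out the rest."""
    k = int(density * tensor.numel())
    if k == 0:
        return torch.zeros_like(tensor)
    # Threshold is the smallest magnitude among the k largest entries.
    threshold = torch.topk(tensor.abs().flatten(), k, largest=True).values.min()
    return tensor * (tensor.abs() >= threshold)


# Example: with density=0.5, the two smaller-magnitude entries are zeroed out.
x = torch.tensor([0.1, -2.0, 0.3, 4.0])
print(magnitude_prune_sketch(x, 0.5))  # tensor([ 0., -2.,  0.,  4.])
```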
29 changes: 29 additions & 0 deletions tests/testing_common.py
@@ -1066,6 +1066,15 @@ def _test_weighted_combination_of_adapters(self, model_id, config_cls, config_kw
density=0.5,
)

# test magnitude_prune_svd re-weighting with multiple adapters
model.add_weighted_adapter(
adapter_list[1:],
weight_list[1:],
"multi_adapter_magnitude_prune_svd_reweighting",
combination_type="magnitude_prune_svd",
density=0.5,
)

# test cat re-weighting with multiple adapters
model.add_weighted_adapter(
adapter_list[1:], weight_list[1:], "multi_adapter_cat_reweighting", combination_type="cat"
@@ -1099,6 +1108,15 @@ def _test_weighted_combination_of_adapters(self, model_id, config_cls, config_kw
density=0.5,
)

# test magnitude_prune re-weighting with multiple adapters
model.add_weighted_adapter(
adapter_list[:2],
weight_list[:2],
"multi_adapter_magnitude_prune_reweighting",
combination_type="magnitude_prune",
density=0.5,
)

# test linear re-weighting with multiple adapters with only first adapter having non zero weight
model.add_weighted_adapter(
adapter_list[:2],
@@ -1142,18 +1160,29 @@ def _test_weighted_combination_of_adapters(self, model_id, config_cls, config_kw
density=0.5,
)

with pytest.raises(ValueError):
model.add_weighted_adapter(
adapter_list[1:],
weight_list[1:],
"multi_adapter_magnitude_prune_reweighting_uneven_r",
combination_type="magnitude_prune",
density=0.5,
)

new_adapters = [
"single_adapter_reweighting",
"multi_adapter_svd_reweighting",
"multi_adapter_ties_svd_reweighting",
"multi_adapter_dare_linear_svd_reweighting",
"multi_adapter_dare_ties_svd_reweighting",
"multi_adapter_magnitude_prune_svd_reweighting",
"multi_adapter_cat_reweighting",
"multi_adapter_linear_reweighting",
"multi_adapter_linear_reweighting_single_enabled",
"multi_adapter_ties_reweighting",
"multi_adapter_dare_linear_reweighting",
"multi_adapter_dare_ties_reweighting",
"multi_adapter_magnitude_prune_reweighting",
]
for new_adapter in new_adapters:
assert new_adapter in model.peft_config
