Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

disable lora #8674

Merged
merged 6 commits into from
Jul 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions docs/llm/docs/peft.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ Prefix-tuning[论文](https://arxiv.org/abs/2101.00190)
target_modules=target_modules,
r=lora_rank,
lora_alpha=2 * lora_rank,
merge_weights=True
)
model = LoRAModel(model, lora_config)
model.mark_only_lora_as_trainable()
Expand Down Expand Up @@ -92,7 +91,7 @@ Parameters:
默认为 0.0,dropout的比例设置,float 类型

--merge_weights
默认为 False,模型推理时,是否进行base model 权重和 LoRA 权重的合参操作,bool 类型
默认为 False,接口将被废弃。请使用model.merge()或model.unmerge()替代。
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

predictor看起来也是需要进行更改


--trainable_bias
指定可训练的 bias, 可选项 ['lora', 'all']
Expand Down
2 changes: 1 addition & 1 deletion llm/predict/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,6 @@ def __init__(
if config.lora_path is not None:
lora_config = LoRAConfig.from_pretrained(config.lora_path)
dtype = lora_config.dtype
lora_config.merge_weights = True
elif config.prefix_path is not None:
prefix_config = PrefixConfig.from_pretrained(config.prefix_path)
dtype = prefix_config.dtype
Expand All @@ -292,6 +291,7 @@ def __init__(
self.model = LoRAModel.from_pretrained(
model=self.model, lora_path=config.lora_path, lora_config=lora_config
)
self.model.merge()
if config.prefix_path is not None:
prefix_tuning_params = get_prefix_tuning_params(self.model)
self.model = PrefixModelForCausalLM.from_pretrained(
Expand Down
8 changes: 6 additions & 2 deletions llm/tools/merge_lora_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,13 +113,17 @@ def lora_process(name, lora_config, state_dict, device, lora_state_dict=None):


def merge_old_lora(lora_config, args):
lora_config.merge_weight = True
lora_config.merge_weights = True
model = AutoModelForCausalLM.from_pretrained(
args.model_name_or_path,
dtype=lora_config.dtype,
)
model = LoRAModel.from_pretrained(model, args.lora_path)
model.eval()
try:
model.merge()
model.eval()
except:
model.eval()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里主要是要什么兼容了? 如果进行了merge,model.eval也是需要打开?

model_state_dict = model.model.state_dict()
for key in list(model_state_dict):
if "lora" in key:
Expand Down
8 changes: 1 addition & 7 deletions paddlenlp/peft/lora/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,5 @@
# limitations under the License.

from .lora_config import LoRAConfig
from .lora_layers import (
ColumnParallelLoRALinear,
ColumnParallelLoRAMergedLinear,
LoRALinear,
LoRAMergedLinear,
RowParallelLoRALinear,
)
from .lora_layers import ColumnParallelLoRALinear, LoRALinear, RowParallelLoRALinear
from .lora_model import LoRAModel
5 changes: 5 additions & 0 deletions paddlenlp/peft/lora/lora_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,11 @@
"We will automatically set `use_quick_lora` to `False` to avoid potential inconsistencies."
)
self.use_quick_lora = False
if self.merge_weights:
logger.error(

Check warning on line 98 in paddlenlp/peft/lora/lora_config.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/peft/lora/lora_config.py#L98

Added line #L98 was not covered by tests
"'merge_weights' is deprecated and will be removed in a future version. "
"Please apply model.merge() or model.unmerge() to merge/unmerge LoRA weight to base model."
)

@property
def scaling(self):
Expand Down
Loading
Loading