Skip to content

Commit

Permalink
ENH: Better error msg for replace_lora_weights_loftq when using a loc…
Browse files Browse the repository at this point in the history
…al model. (#2022)

Resolves #2020

If users want to use a local model, they need to pass the model_path
argument. The error message now says so.
  • Loading branch information
BenjaminBossan authored Aug 21, 2024
1 parent 25ab6c9 commit 95821e5
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 3 deletions.
7 changes: 4 additions & 3 deletions src/peft/utils/loftq_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import torch
from huggingface_hub import snapshot_download
from huggingface_hub.errors import HFValidationError
from huggingface_hub.utils import LocalEntryNotFoundError
from safetensors import SafetensorError, safe_open
from transformers.utils import cached_file
Expand Down Expand Up @@ -271,10 +272,10 @@ def __init__(self, peft_model, model_path):
if model_path is None:
try:
model_path = snapshot_download(peft_model.base_model.config._name_or_path, local_files_only=True)
except AttributeError as exc:
except (AttributeError, HFValidationError) as exc:
raise ValueError(
"The provided model does not appear to be a transformers model. In this case, you must pass the "
"model_path to the safetensors file."
"The provided model does not appear to be a transformers model or is a local model. In this case, "
"you must pass the model_path argument that points to the safetensors file."
) from exc
except LocalEntryNotFoundError as exc:
raise ValueError(
Expand Down
39 changes: 39 additions & 0 deletions tests/test_gpu_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -2108,6 +2108,45 @@ def my_callback(model, module_name):
torch.cuda.empty_cache()
gc.collect()

def test_replace_lora_weights_with_local_model(self):
# see issue 2020
torch.manual_seed(0)
model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
device = "cuda"

with tempfile.TemporaryDirectory() as tmp_dir:
# save base model locally
model = AutoModelForCausalLM.from_pretrained(model_id).to(device)
model.save_pretrained(tmp_dir)
del model

# load in 4bit
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
)

# load the base model from local directory
model = AutoModelForCausalLM.from_pretrained(tmp_dir, quantization_config=bnb_config)
model = get_peft_model(model, LoraConfig())

# passing the local path directly works
replace_lora_weights_loftq(model, model_path=tmp_dir)
del model

# load the base model from local directory
model = AutoModelForCausalLM.from_pretrained(tmp_dir, quantization_config=bnb_config)
model = get_peft_model(model, LoraConfig())

# when not passing, ensure that users are made aware of the `model_path` argument
with pytest.raises(ValueError, match="model_path"):
replace_lora_weights_loftq(model)

del model
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()


@require_bitsandbytes
@require_torch_gpu
Expand Down

0 comments on commit 95821e5

Please sign in to comment.