[BUG] GPTQ quantization fails after merging a LoRA fine-tuned 14B model #824

Closed
2 tasks done
micronetboy opened this issue Dec 19, 2023 · 2 comments

Comments


micronetboy commented Dec 19, 2023

Is there an existing issue / discussion for this?

  • I have searched the existing issues / discussions

Is there an existing answer for this in the FAQ?

  • I have searched the FAQ

Current Behavior

After LoRA fine-tuning, I merged the adapters into the base model and saved the result to ./merged_14b (a minimal sketch of the merge step follows the traceback). Converting it to a GPTQ Int4 quantized model then fails with:

Traceback (most recent call last):
  File "/home/venus/IFunAIPlat/MetaHuman/finetune/gptq.py", line 29, in <module>
    model.quantize(examples)
  File "/home/venus/anaconda3/envs/qwenvllmgptq/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/venus/anaconda3/envs/qwenvllmgptq/lib/python3.10/site-packages/auto_gptq/modeling/_base.py", line 359, in quantize
    layer(layer_input, **additional_layer_inputs)
  File "/home/venus/anaconda3/envs/qwenvllmgptq/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/venus/anaconda3/envs/qwenvllmgptq/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/venus/.cache/huggingface/modules/transformers_modules/Qwen/Qwen-14B-Chat/cdaff792392504e679496a9f386acf3c1e4333a5/modeling_qwen.py", line 610, in forward
    attn_outputs = self.attn(
  File "/home/venus/anaconda3/envs/qwenvllmgptq/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/venus/anaconda3/envs/qwenvllmgptq/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/venus/.cache/huggingface/modules/transformers_modules/Qwen/Qwen-14B-Chat/cdaff792392504e679496a9f386acf3c1e4333a5/modeling_qwen.py", line 432, in forward
    query = apply_rotary_pos_emb(query, q_pos_emb)
  File "/home/venus/.cache/huggingface/modules/transformers_modules/Qwen/Qwen-14B-Chat/cdaff792392504e679496a9f386acf3c1e4333a5/modeling_qwen.py", line 1345, in apply_rotary_pos_emb
    t_rot = (t_rot * cos) + (_rotate_half(t_rot) * sin)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
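
For context, the ./merged_14b checkpoint was produced by merging the LoRA adapters into the base model. A minimal sketch of that step, assuming a standard PEFT merge; the base-model name and the adapter directory below are placeholders, not paths taken from this issue:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# load the base model the LoRA adapters were trained on (placeholder name)
base = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen-14B-Chat", torch_dtype=torch.float16, trust_remote_code=True
)

# attach the LoRA adapters (placeholder directory) and fold them into the base weights
model = PeftModel.from_pretrained(base, "./lora_output")
merged = model.merge_and_unload()

# save the merged full-precision model and tokenizer for the GPTQ step below
merged.save_pretrained("./merged_14b")
AutoTokenizer.from_pretrained("Qwen/Qwen-14B-Chat", trust_remote_code=True).save_pretrained("./merged_14b")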

Below is the quantization code I used, taken from the AutoGPTQ README.

from transformers import AutoTokenizer, TextGenerationPipeline
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
import logging

logging.basicConfig(
    format="%(asctime)s %(levelname)s [%(name)s] %(message)s", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S"
)

pretrained_model_dir = "./merged_14b"
quantized_model_dir = "./merged_14b_lora_gptq"

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)
examples = [
    tokenizer(
        "auto-gptq is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm."
    )
]

quantize_config = BaseQuantizeConfig(
    bits=4,  # quantize model to 4-bit
    group_size=128,  # it is recommended to set the value to 128
    desc_act=False,  # setting this to False can significantly speed up inference, but perplexity may be slightly worse
)

# load the un-quantized model; by default it is always loaded into CPU memory
model = AutoGPTQForCausalLM.from_pretrained(pretrained_model_dir, quantize_config)

# quantize the model; the examples should be a list of dicts whose only keys are "input_ids" and "attention_mask"
model.quantize(examples)

# save quantized model
model.save_quantized(quantized_model_dir)

# save quantized model using safetensors
model.save_quantized(quantized_model_dir, use_safetensors=True)

# push quantized model to Hugging Face Hub.
# to use use_auth_token=True, log in first via huggingface-cli login,
# or pass an explicit token with: use_auth_token="hf_xxxxxxx"
# (uncomment the following three lines to enable this feature)
# repo_id = f"YourUserName/{quantized_model_dir}"
# commit_message = f"AutoGPTQ model for {pretrained_model_dir}: {quantize_config.bits}bits, gr{quantize_config.group_size}, desc_act={quantize_config.desc_act}"
# model.push_to_hub(repo_id, commit_message=commit_message, use_auth_token=True)

# alternatively you can save and push at the same time
# (uncomment the following three lines to enable this feature)
# repo_id = f"YourUserName/{quantized_model_dir}"
# commit_message = f"AutoGPTQ model for {pretrained_model_dir}: {quantize_config.bits}bits, gr{quantize_config.group_size}, desc_act={quantize_config.desc_act}"
# model.push_to_hub(repo_id, save_dir=quantized_model_dir, use_safetensors=True, commit_message=commit_message, use_auth_token=True)

# load quantized model to the first GPU
#model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir, device="cuda:0")

# download quantized model from Hugging Face Hub and load to the first GPU
# model = AutoGPTQForCausalLM.from_quantized(repo_id, device="cuda:0", use_safetensors=True, use_triton=False)

# inference with model.generate
#print(tokenizer.decode(model.generate(**tokenizer("auto_gptq is", return_tensors="pt").to(model.device))[0]))

# or you can also use pipeline
#pipeline = TextGenerationPipeline(model=model, tokenizer=tokenizer)
#print(pipeline("auto-gptq is")[0]["generated_text"])
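
The RuntimeError above is raised in apply_rotary_pos_emb during the calibration forward pass, which suggests some tensors (for example the rotary cos/sin values) are still on the CPU while the layer inputs are on cuda:0. Before calling model.quantize(examples), one way to see where the parameters and buffers actually live is a quick check like the sketch below; this diagnostic code is an illustration added here, not part of the original script, and it assumes the AutoGPTQ wrapper exposes its underlying nn.Module as model.model:

from collections import Counter

def device_report(module):
    # count parameters and buffers per device to spot anything left on the CPU
    counts = Counter()
    for _, p in module.named_parameters():
        counts[str(p.device)] += 1
    for _, b in module.named_buffers():
        counts[str(b.device)] += 1
    return counts

# `model` is the AutoGPTQForCausalLM wrapper created above (attribute name assumed)
print(device_report(model.model))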

Expected Behavior

No response

Steps To Reproduce

No response

Environment

- OS:
- Python:
- Transformers:
- PyTorch:
- CUDA (`python -c 'import torch; print(torch.version.cuda)'`):

Anything else?

No response


jklj077 (Contributor) commented Dec 26, 2023

We have added instructions on model quantization to the README and submitted a corresponding PR to AutoGPTQ for this problem. Please give it another try.

jklj077 closed this as completed Dec 26, 2023