Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Qwen2Moe GGUF loading support #33264

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/en/gguf.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ For now the supported model architectures are the architectures that have been v
- LLaMa
- Mistral
- Qwen2
- Qwen2Moe

## Example usage

Expand Down
38 changes: 37 additions & 1 deletion src/transformers/integrations/ggml.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,21 @@
"output.weight": "lm_head.weight",
"output_norm": "model.norm",
},
"qwen2moe": {
"token_embd": "model.embed_tokens",
"blk": "model.layers",
"ffn_up": "mlp.up_proj",
"ffn_down": "mlp.down_proj",
"ffn_gate": "mlp.gate_proj",
"ffn_norm": "post_attention_layernorm",
"attn_norm": "input_layernorm",
"attn_q": "self_attn.q_proj",
"attn_v": "self_attn.v_proj",
"attn_k": "self_attn.k_proj",
"attn_output": "self_attn.o_proj",
"output.weight": "lm_head.weight",
"output_norm": "model.norm",
},
}


Expand Down Expand Up @@ -161,6 +176,18 @@
"attention.layer_norm_rms_epsilon": "rms_norm_eps",
"vocab_size": "vocab_size",
},
"qwen2moe": {
"context_length": "max_position_embeddings",
"block_count": "num_hidden_layers",
"feed_forward_length": "intermediate_size",
"embedding_length": "hidden_size",
"rope.dimension_count": None,
"rope.freq_base": "rope_theta",
"attention.head_count": "num_attention_heads",
"attention.head_count_kv": "num_key_value_heads",
"attention.layer_norm_rms_epsilon": "rms_norm_eps",
"vocab_size": "vocab_size",
},
"tokenizer": {
"ggml.bos_token_id": "bos_token_id",
"ggml.eos_token_id": "eos_token_id",
Expand Down Expand Up @@ -579,7 +606,15 @@ def tokenizer(self, proto):
bos_token = proto.tokens[proto.bos_token_id] if getattr(proto, "bos_token_id", None) is not None else None
eos_token = proto.tokens[proto.bos_token_id] if getattr(proto, "eos_token_id", None) is not None else None

tokenizer = Tokenizer(BPE(bpe_vocab, merges, unk_token=unk_token, fuse_unk=True, byte_fallback=True))
tokenizer = Tokenizer(
BPE(
bpe_vocab,
merges,
unk_token=unk_token,
fuse_unk=True,
byte_fallback=True,
)
)

special_tokens = []

Expand Down Expand Up @@ -693,6 +728,7 @@ def converted(self) -> Tokenizer:
GGUF_TO_FAST_CONVERTERS = {
"llama": GGUFLlamaConverter,
"qwen2": GGUFQwen2Converter,
"qwen2_moe": GGUFQwen2Converter,
}


Expand Down
3 changes: 3 additions & 0 deletions src/transformers/modeling_gguf_pytorch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,9 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
else:
updated_architecture = architecture

if "qwen2moe" in architecture:
updated_architecture = "qwen2_moe"

if architecture not in GGUF_SUPPORTED_ARCHITECTURES:
raise ValueError(f"Architecture {architecture} not supported")

Expand Down
39 changes: 35 additions & 4 deletions tests/quantization/ggml/test_ggml.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@
import unittest

from transformers import AddedToken, AutoModelForCausalLM, AutoTokenizer
from transformers.testing_utils import require_gguf, require_torch_gpu, slow, torch_device
from transformers.testing_utils import (
require_gguf,
require_torch_gpu,
slow,
torch_device,
)
from transformers.utils import is_torch_available


Expand All @@ -32,6 +37,7 @@ class GgufIntegrationTests(unittest.TestCase):
model_id = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF"
mistral_model_id = "TheBloke/Mistral-7B-Instruct-v0.2-GGUF"
qwen2_model_id = "Qwen/Qwen1.5-0.5B-Chat-GGUF"
qwen2_moe_model_id = "RichardErkhov/Qwen_-_Qwen1.5-MoE-A2.7B-Chat-gguf"
llama3_model_id = "NousResearch/Meta-Llama-3-8B-GGUF"
tinyllama_model_id = "PenutChen/TinyLlama-1.1B-Chat-v1.0-GGUF"

Expand All @@ -45,6 +51,7 @@ class GgufIntegrationTests(unittest.TestCase):

q4_0_mistral_model_id = "mistral-7b-instruct-v0.2.Q4_0.gguf"
q4_0_qwen2_model_id = "qwen1_5-0_5b-chat-q4_0.gguf"
q4_0_qwen2_moe_model_id = "Qwen1.5-MoE-A2.7B-Chat.Q4_0.gguf"
q4_llama3_model_id = "Meta-Llama-3-8B-Q4_K_M.gguf"
f16_tinyllama_model_id = "TinyLlama-1.1B-Chat-v1.0.FP16.gguf"

Expand Down Expand Up @@ -166,7 +173,10 @@ def test_f16(self):
def test_mistral_q4_0(self):
tokenizer = AutoTokenizer.from_pretrained(self.mistral_model_id, gguf_file=self.q4_0_mistral_model_id)
model = AutoModelForCausalLM.from_pretrained(
self.mistral_model_id, gguf_file=self.q4_0_mistral_model_id, device_map="auto", torch_dtype=torch.float16
self.mistral_model_id,
gguf_file=self.q4_0_mistral_model_id,
device_map="auto",
torch_dtype=torch.float16,
)

text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
Expand All @@ -178,7 +188,10 @@ def test_mistral_q4_0(self):
def test_qwen2_q4_0(self):
tokenizer = AutoTokenizer.from_pretrained(self.qwen2_model_id, gguf_file=self.q4_0_qwen2_model_id)
model = AutoModelForCausalLM.from_pretrained(
self.qwen2_model_id, gguf_file=self.q4_0_qwen2_model_id, device_map="auto", torch_dtype=torch.float16
self.qwen2_model_id,
gguf_file=self.q4_0_qwen2_model_id,
device_map="auto",
torch_dtype=torch.float16,
)

text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
Expand All @@ -187,6 +200,21 @@ def test_qwen2_q4_0(self):
EXPECTED_TEXT = "Hello.jsoup\n\nI am a beginner"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)

def test_qwen2_moe_q4_0(self):
tokenizer = AutoTokenizer.from_pretrained(self.qwen2_moe_model_id, gguf_file=self.q4_0_qwen2_moe_model_id)
model = AutoModelForCausalLM.from_pretrained(
self.qwen2_moe_model_id,
gguf_file=self.q4_0_qwen2_moe_model_id,
device_map="auto",
torch_dtype=torch.float16,
)

text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
out = model.generate(**text, max_new_tokens=10)

EXPECTED_TEXT = "Hello everyone, I'm a newbie here and would like"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)

def test_llama3_q4_0_tokenizer(self):
tokenizer = AutoTokenizer.from_pretrained(self.llama3_model_id, gguf_file=self.q4_llama3_model_id)
with tempfile.TemporaryDirectory() as tmpdirname:
Expand All @@ -199,7 +227,10 @@ def test_llama3_q4_0_tokenizer(self):
def test_llama3_q4_0(self):
tokenizer = AutoTokenizer.from_pretrained(self.llama3_model_id, gguf_file=self.q4_llama3_model_id)
model = AutoModelForCausalLM.from_pretrained(
self.llama3_model_id, gguf_file=self.q4_llama3_model_id, device_map="auto", torch_dtype=torch.float16
self.llama3_model_id,
gguf_file=self.q4_llama3_model_id,
device_map="auto",
torch_dtype=torch.float16,
)

text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
Expand Down
Loading