Skip to content

Commit

Permalink
Add Qwen2Moe GGUF loading support (huggingface#33264)
Browse files Browse the repository at this point in the history
* update gguf doc, config and tensor mapping

* add qwen2moe architecture support, GGUFQwen2MoeConverter and q4 unit tests

* apply code style fixes

* reformat files

* assign GGUFQwen2Converter to qwen2_moe
  • Loading branch information
VladOS95-cyber authored and BernardZach committed Dec 5, 2024
1 parent af05eb9 commit 897ae85
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 5 deletions.
1 change: 1 addition & 0 deletions docs/source/en/gguf.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ For now the supported model architectures are the architectures that have been v
- LLaMa
- Mistral
- Qwen2
- Qwen2Moe

## Example usage

Expand Down
38 changes: 37 additions & 1 deletion src/transformers/integrations/ggml.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,21 @@
"output.weight": "lm_head.weight",
"output_norm": "model.norm",
},
"qwen2moe": {
"token_embd": "model.embed_tokens",
"blk": "model.layers",
"ffn_up": "mlp.up_proj",
"ffn_down": "mlp.down_proj",
"ffn_gate": "mlp.gate_proj",
"ffn_norm": "post_attention_layernorm",
"attn_norm": "input_layernorm",
"attn_q": "self_attn.q_proj",
"attn_v": "self_attn.v_proj",
"attn_k": "self_attn.k_proj",
"attn_output": "self_attn.o_proj",
"output.weight": "lm_head.weight",
"output_norm": "model.norm",
},
}


Expand Down Expand Up @@ -123,6 +138,18 @@
"attention.layer_norm_rms_epsilon": "rms_norm_eps",
"vocab_size": "vocab_size",
},
"qwen2moe": {
"context_length": "max_position_embeddings",
"block_count": "num_hidden_layers",
"feed_forward_length": "intermediate_size",
"embedding_length": "hidden_size",
"rope.dimension_count": None,
"rope.freq_base": "rope_theta",
"attention.head_count": "num_attention_heads",
"attention.head_count_kv": "num_key_value_heads",
"attention.layer_norm_rms_epsilon": "rms_norm_eps",
"vocab_size": "vocab_size",
},
"tokenizer": {
"ggml.bos_token_id": "bos_token_id",
"ggml.eos_token_id": "eos_token_id",
Expand Down Expand Up @@ -244,7 +271,15 @@ def tokenizer(self, proto):
bos_token = proto.tokens[proto.bos_token_id] if getattr(proto, "bos_token_id", None) is not None else None
eos_token = proto.tokens[proto.bos_token_id] if getattr(proto, "eos_token_id", None) is not None else None

tokenizer = Tokenizer(BPE(bpe_vocab, merges, unk_token=unk_token, fuse_unk=True, byte_fallback=True))
tokenizer = Tokenizer(
BPE(
bpe_vocab,
merges,
unk_token=unk_token,
fuse_unk=True,
byte_fallback=True,
)
)

special_tokens = []

Expand Down Expand Up @@ -358,6 +393,7 @@ def converted(self) -> Tokenizer:
GGUF_TO_FAST_CONVERTERS = {
"llama": GGUFLlamaConverter,
"qwen2": GGUFQwen2Converter,
"qwen2_moe": GGUFQwen2Converter,
}


Expand Down
3 changes: 3 additions & 0 deletions src/transformers/modeling_gguf_pytorch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,9 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
else:
updated_architecture = architecture

if "qwen2moe" in architecture:
updated_architecture = "qwen2_moe"

if architecture not in GGUF_SUPPORTED_ARCHITECTURES:
raise ValueError(f"Architecture {architecture} not supported")

Expand Down
39 changes: 35 additions & 4 deletions tests/quantization/ggml/test_ggml.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@
import unittest

from transformers import AddedToken, AutoModelForCausalLM, AutoTokenizer
from transformers.testing_utils import require_gguf, require_torch_gpu, slow, torch_device
from transformers.testing_utils import (
require_gguf,
require_torch_gpu,
slow,
torch_device,
)
from transformers.utils import is_torch_available


Expand All @@ -33,6 +38,7 @@ class GgufIntegrationTests(unittest.TestCase):
imatrix_model_id = "duyntnet/TinyLlama-1.1B-Chat-v1.0-imatrix-GGUF"
mistral_model_id = "TheBloke/Mistral-7B-Instruct-v0.2-GGUF"
qwen2_model_id = "Qwen/Qwen1.5-0.5B-Chat-GGUF"
qwen2_moe_model_id = "RichardErkhov/Qwen_-_Qwen1.5-MoE-A2.7B-Chat-gguf"
llama3_model_id = "NousResearch/Meta-Llama-3-8B-GGUF"
tinyllama_model_id = "PenutChen/TinyLlama-1.1B-Chat-v1.0-GGUF"

Expand All @@ -59,6 +65,7 @@ class GgufIntegrationTests(unittest.TestCase):

q4_0_mistral_model_id = "mistral-7b-instruct-v0.2.Q4_0.gguf"
q4_0_qwen2_model_id = "qwen1_5-0_5b-chat-q4_0.gguf"
q4_0_qwen2_moe_model_id = "Qwen1.5-MoE-A2.7B-Chat.Q4_0.gguf"
q4_llama3_model_id = "Meta-Llama-3-8B-Q4_K_M.gguf"
f16_tinyllama_model_id = "TinyLlama-1.1B-Chat-v1.0.FP16.gguf"

Expand Down Expand Up @@ -298,7 +305,10 @@ def test_f16(self):
def test_mistral_q4_0(self):
tokenizer = AutoTokenizer.from_pretrained(self.mistral_model_id, gguf_file=self.q4_0_mistral_model_id)
model = AutoModelForCausalLM.from_pretrained(
self.mistral_model_id, gguf_file=self.q4_0_mistral_model_id, device_map="auto", torch_dtype=torch.float16
self.mistral_model_id,
gguf_file=self.q4_0_mistral_model_id,
device_map="auto",
torch_dtype=torch.float16,
)

text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
Expand All @@ -310,7 +320,10 @@ def test_mistral_q4_0(self):
def test_qwen2_q4_0(self):
tokenizer = AutoTokenizer.from_pretrained(self.qwen2_model_id, gguf_file=self.q4_0_qwen2_model_id)
model = AutoModelForCausalLM.from_pretrained(
self.qwen2_model_id, gguf_file=self.q4_0_qwen2_model_id, device_map="auto", torch_dtype=torch.float16
self.qwen2_model_id,
gguf_file=self.q4_0_qwen2_model_id,
device_map="auto",
torch_dtype=torch.float16,
)

text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
Expand All @@ -319,6 +332,21 @@ def test_qwen2_q4_0(self):
EXPECTED_TEXT = "Hello.jsoup\n\nI am a beginner"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)

def test_qwen2_moe_q4_0(self):
tokenizer = AutoTokenizer.from_pretrained(self.qwen2_moe_model_id, gguf_file=self.q4_0_qwen2_moe_model_id)
model = AutoModelForCausalLM.from_pretrained(
self.qwen2_moe_model_id,
gguf_file=self.q4_0_qwen2_moe_model_id,
device_map="auto",
torch_dtype=torch.float16,
)

text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
out = model.generate(**text, max_new_tokens=10)

EXPECTED_TEXT = "Hello everyone, I'm a newbie here and would like"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)

def test_llama3_q4_0_tokenizer(self):
tokenizer = AutoTokenizer.from_pretrained(self.llama3_model_id, gguf_file=self.q4_llama3_model_id)
with tempfile.TemporaryDirectory() as tmpdirname:
Expand All @@ -331,7 +359,10 @@ def test_llama3_q4_0_tokenizer(self):
def test_llama3_q4_0(self):
tokenizer = AutoTokenizer.from_pretrained(self.llama3_model_id, gguf_file=self.q4_llama3_model_id)
model = AutoModelForCausalLM.from_pretrained(
self.llama3_model_id, gguf_file=self.q4_llama3_model_id, device_map="auto", torch_dtype=torch.float16
self.llama3_model_id,
gguf_file=self.q4_llama3_model_id,
device_map="auto",
torch_dtype=torch.float16,
)

text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
Expand Down

0 comments on commit 897ae85

Please sign in to comment.