Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add gguf support for gpt2 #34044

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/en/gguf.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ For now the supported model architectures are the architectures that have been v
- Bloom
- Falcon
- StableLM
- GPT2

## Example usage

Expand Down
22 changes: 22 additions & 0 deletions src/transformers/integrations/ggml.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,19 @@
"output.weight": "lm_head.weight",
"output_norm": "model.norm",
},
"gpt2": {
"token_embd": "transformer.wte",
"blk": "transformer.h",
"position_embd": "transformer.wpe",
"output_norm": "transformer.ln_f",
"attn_norm": "ln_1",
"attn_qkv": "attn.c_attn",
"attn_output.weight": "attn.c_proj.weight",
"attn_output.bias": "attn.c_proj.bias",
"ffn_norm": "ln_2",
"ffn_up": "mlp.c_fc",
"ffn_down": "mlp.c_proj",
},
}


Expand Down Expand Up @@ -271,6 +284,14 @@
"attention.layer_norm_epsilon": "layer_norm_eps",
"vocab_size": "vocab_size",
},
"gpt2": {
"block_count": "n_layer",
"context_length": "n_ctx",
"embedding_length": "n_embd",
"feed_forward_length": "feed_forward_length",
"attention.head_count": "n_head",
"attention.layer_norm_epsilon": "layer_norm_epsilon",
},
}

GGUF_TOKENIZER_MAPPING = {
Expand Down Expand Up @@ -600,6 +621,7 @@ def converted(self) -> Tokenizer:
"bloom": GGUFGPTConverter,
"falcon": GGUFGPTConverter,
"stablelm": GGUFGPTConverter,
"gpt2": GGUFGPTConverter,
}


Expand Down
17 changes: 17 additions & 0 deletions src/transformers/modeling_gguf_pytorch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,23 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
else:
weights = reverse_reshape_bias(weights, num_heads, n_embed)

if architecture == "gpt2":
if (
"attn_qkv.weight" in name
or "ffn_down.weight" in name
or "ffn_up.weight" in name
or "attn_output.weight" in name
):
# Original transpose implementation
# https://github.com/ggerganov/llama.cpp/blob/a38b884c6c4b0c256583acfaaabdf556c62fabea/convert_hf_to_gguf.py#L2060-L2061
weights = weights.T
if name == "output.weight":
# output.weight has conflicts with attn_output.weight in name checking
# we have to explicitly check that name is exactly output.weight
name = "lm_head.weight"
parsed_parameters["tensors"][name] = torch.from_numpy(np.copy(weights))
continue

for tensor_name in tensor_key_mapping:
if tensor_name in name:
name = name.replace(tensor_name, tensor_key_mapping[tensor_name])
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/gpt2/tokenization_gpt2_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@ def __init__(
**kwargs,
):
super().__init__(
vocab_file,
merges_file,
vocab_file=vocab_file,
merges_file=merges_file,
tokenizer_file=tokenizer_file,
unk_token=unk_token,
bos_token=bos_token,
Expand Down
53 changes: 53 additions & 0 deletions tests/quantization/ggml/test_ggml.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ class GgufIntegrationTests(unittest.TestCase):
stablelm_model_id = "afrideva/stablelm-3b-4e1t-GGUF"
stablelm2_model_id = "afrideva/stablelm-2-1_6b-GGUF"
original_stablelm2_model_id = "stabilityai/stablelm-2-1_6b"
gpt2_model_id = "mradermacher/gpt2-GGUF"
gpt2_original_model_id = "openai-community/gpt2"
gpt2_xl_model_id = "RichardErkhov/openai-community_-_gpt2-xl-gguf"

# standard quants
q4_0_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q4_0.gguf"
Expand Down Expand Up @@ -87,6 +90,9 @@ class GgufIntegrationTests(unittest.TestCase):
fp16_falcon7b_model_id = "falcon-7b-fp16.gguf"
q2_k_falcon40b_model_id = "tiiuae-falcon-40b-Q2_K.gguf"
fp16_qwen2moe_model_id = "Qwen1.5-MoE-A2.7B.gguf"
fp16_gpt2_model_id = "gpt2.f16.gguf"
q8_gpt2_model_id = "gpt2.Q8_0.gguf"
q6_k_gpt2_xl_model_id = "gpt2-xl.Q6_K.gguf"

example_text = "Hello"

Expand Down Expand Up @@ -476,6 +482,53 @@ def test_bloom_weights_conversion_fp16(self):
self.assertTrue(quantized_param.shape == original_param.shape)
torch.testing.assert_close(quantized_param, original_param)

def test_gpt2_q8(self):
tokenizer = AutoTokenizer.from_pretrained(self.gpt2_model_id, gguf_file=self.q8_gpt2_model_id)
model = AutoModelForCausalLM.from_pretrained(
self.gpt2_model_id,
gguf_file=self.q8_gpt2_model_id,
torch_dtype=torch.float16,
)

text = tokenizer(self.example_text, return_tensors="pt")
out = model.generate(**text, max_new_tokens=10)

EXPECTED_TEXT = "Hello, I'm sorry. I'm sorry. I"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)

def test_gpt2_weights_conversion_fp16(self):
quantized_model = AutoModelForCausalLM.from_pretrained(
self.gpt2_model_id,
gguf_file=self.fp16_gpt2_model_id,
torch_dtype=torch.float16,
)
original_model = AutoModelForCausalLM.from_pretrained(
self.gpt2_original_model_id,
torch_dtype=torch.float16,
)

quantized_state_dict = quantized_model.state_dict()
original_state_dict = original_model.state_dict()

for layer_name, original_params in original_state_dict.items():
if layer_name in quantized_state_dict:
self.assertTrue(original_params.shape == quantized_state_dict[layer_name].shape)
torch.testing.assert_close(original_params, quantized_state_dict[layer_name])

def test_gpt2_xl_Q6_K(self):
tokenizer = AutoTokenizer.from_pretrained(self.gpt2_xl_model_id, gguf_file=self.q6_k_gpt2_xl_model_id)
model = AutoModelForCausalLM.from_pretrained(
self.gpt2_xl_model_id,
gguf_file=self.q6_k_gpt2_xl_model_id,
torch_dtype=torch.float16,
)

text = tokenizer(self.example_text, return_tensors="pt")
out = model.generate(**text, max_new_tokens=10)

EXPECTED_TEXT = "Hello, I'm a newbie to the world of"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)

@unittest.skip(reason="Heavy memory")
def test_falcon40b_q2_k(self):
tokenizer = AutoTokenizer.from_pretrained(self.falcon40b_model_id, gguf_file=self.q2_k_falcon40b_model_id)
Expand Down