Add gguf support for bloom (#33473)
* add bloom arch support for gguf

* apply format

* small refactoring, bug fix in GGUF_TENSOR_MAPPING naming

* optimize bloom GGUF_TENSOR_MAPPING

* implement reverse reshaping for bloom gguf

* add qkv weights test

* add q_8 test for bloom
VladOS95-cyber authored Sep 27, 2024
1 parent 3e039d3 commit 9d200cf
Showing 6 changed files with 140 additions and 8 deletions.
1 change: 1 addition & 0 deletions docs/source/en/gguf.md
@@ -80,6 +80,7 @@ For now the supported model architectures are the architectures that have been v
- Qwen2
- Qwen2Moe
- Phi3
- Bloom

## Example usage

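As an aside (not part of the diff): loading one of the newly supported Bloom GGUF checkpoints follows the same pattern as the other architectures. A minimal sketch, using the repository and file names from the integration tests added later in this commit:

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "afrideva/bloom-560m-GGUF"   # GGUF repo used in the tests added by this commit
gguf_file = "bloom-560m.q8_0.gguf"      # Q8_0 quantized Bloom 560M checkpoint

# Both calls parse the GGUF file and convert/dequantize it into regular transformers objects.
tokenizer = AutoTokenizer.from_pretrained(model_id, gguf_file=gguf_file)
model = AutoModelForCausalLM.from_pretrained(model_id, gguf_file=gguf_file)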
14 changes: 9 additions & 5 deletions src/transformers/convert_slow_tokenizer.py
@@ -328,9 +328,11 @@ def converted(self) -> Tokenizer:


class GPT2Converter(Converter):
def converted(self) -> Tokenizer:
vocab = self.original_tokenizer.encoder
merges = list(self.original_tokenizer.bpe_ranks.keys())
def converted(self, vocab: Dict[str, int] = None, merges: List[Tuple[str, str]] = None) -> Tokenizer:
if not vocab:
vocab = self.original_tokenizer.encoder
if not merges:
merges = list(self.original_tokenizer.bpe_ranks)

tokenizer = Tokenizer(
BPE(
@@ -343,9 +345,11 @@ def converted(self) -> Tokenizer:
)
)

tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=self.original_tokenizer.add_prefix_space)
add_prefix_space = False
add_prefix_space = getattr(self.original_tokenizer, "add_prefix_space", False)
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=add_prefix_space)
tokenizer.decoder = decoders.ByteLevel()
if self.original_tokenizer.add_bos_token:
if getattr(self.original_tokenizer, "add_bos_token", False):
bos = self.original_tokenizer.bos_token
bos_token_id = self.original_tokenizer.bos_token_id
tokenizer.post_processor = processors.TemplateProcessing(
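The point of letting converted() accept an explicit vocab and merges is that a GGUF-backed tokenizer has no slow GPT-2 tokenizer carrying encoder/bpe_ranks; the GGUFBloomConverter added below passes them in directly. A rough sketch of that call path with a toy vocabulary (illustrative only; the real converter wraps a GGUFTokenizerSkeleton rather than None):

from transformers.convert_slow_tokenizer import GPT2Converter

# Toy byte-level BPE data, purely illustrative.
vocab = {"H": 0, "e": 1, "l": 2, "o": 3, "lo": 4, "llo": 5, "He": 6}
merges = [("l", "o"), ("l", "lo"), ("H", "e")]

converter = GPT2Converter(None)  # no slow tokenizer; the getattr(..., default) fallbacks above cover this
fast_tokenizer = converter.converted(vocab, merges)
print(fast_tokenizer.encode("Hello").tokens)  # ['He', 'llo']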
35 changes: 34 additions & 1 deletion src/transformers/integrations/ggml.py
@@ -25,7 +25,7 @@
from tokenizers.models import BPE

from .. import AddedToken
from ..convert_slow_tokenizer import LlamaConverter, Qwen2Converter
from ..convert_slow_tokenizer import GPT2Converter, LlamaConverter, Qwen2Converter
from ..utils import logging
from ..utils.logging import tqdm

@@ -107,6 +107,19 @@
"output.weight": "lm_head.weight",
"output_norm": "model.norm",
},
"bloom": {
"token_embd.weight": "transformer.word_embeddings.weight",
"token_embd_norm": "transformer.word_embeddings_layernorm",
"blk": "transformer.h",
"ffn_up": "mlp.dense_h_to_4h",
"ffn_down": "mlp.dense_4h_to_h",
"ffn_norm": "post_attention_layernorm",
"attn_norm": "input_layernorm",
"attn_qkv": "self_attention.query_key_value",
"attn_output": "self_attention.dense",
"output.weight": "lm_head.weight",
"output_norm": "transformer.ln_f",
},
}
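These entries are applied as plain substring replacements when the checkpoint is loaded (see the renaming loop in modeling_gguf_pytorch_utils.py further down). A small sketch of how a Bloom GGUF tensor name ends up as the matching transformers parameter name, using a subset of the mapping above:

bloom_mapping = {
    "blk": "transformer.h",
    "attn_qkv": "self_attention.query_key_value",
}

name = "blk.0.attn_qkv.weight"
for gguf_key, hf_key in bloom_mapping.items():
    if gguf_key in name:
        name = name.replace(gguf_key, hf_key)

print(name)  # transformer.h.0.self_attention.query_key_value.weight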


@@ -183,6 +196,13 @@
"attention.layer_norm_rms_epsilon": "rms_norm_eps",
"vocab_size": "vocab_size",
},
"bloom": {
"block_count": "n_layer",
"embedding_length": "hidden_size",
"attention.head_count": "n_head",
"vocab_size": "vocab_size",
"attention.layer_norm_epsilon": "layer_norm_epsilon",
},
}
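A sketch (not from the diff) of what this mapping produces: GGUF stores these fields under architecture-prefixed metadata keys such as "bloom.block_count", and after stripping the prefix they become BloomConfig keyword arguments. The values below match bloom-560m but are only illustrative:

gguf_metadata = {
    "block_count": 24,
    "embedding_length": 1024,
    "attention.head_count": 16,
    "vocab_size": 250880,
    "attention.layer_norm_epsilon": 1e-5,
}

bloom_config_mapping = {
    "block_count": "n_layer",
    "embedding_length": "hidden_size",
    "attention.head_count": "n_head",
    "vocab_size": "vocab_size",
    "attention.layer_norm_epsilon": "layer_norm_epsilon",
}

config_kwargs = {bloom_config_mapping[key]: value for key, value in gguf_metadata.items()}
# {'n_layer': 24, 'hidden_size': 1024, 'n_head': 16, 'vocab_size': 250880, 'layer_norm_epsilon': 1e-05}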

GGUF_TOKENIZER_MAPPING = {
@@ -492,11 +512,24 @@ def converted(self) -> Tokenizer:
return tokenizer


class GGUFBloomConverter(GPT2Converter):
def __init__(self, tokenizer_dict):
self.original_tokenizer = GGUFTokenizerSkeleton(tokenizer_dict)
self.additional_kwargs = {}

def converted(self) -> Tokenizer:
vocab = {word: i for i, word in enumerate(self.original_tokenizer.tokens)}
merges = self.original_tokenizer.merges
tokenizer = super().converted(vocab, merges)
return tokenizer


GGUF_TO_FAST_CONVERTERS = {
"llama": GGUFLlamaConverter,
"qwen2": GGUFQwen2Converter,
"qwen2_moe": GGUFQwen2Converter,
"phi3": GGUFPhi3Converter,
"bloom": GGUFBloomConverter,
}
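For context, a minimal sketch of how such a registry could be used to reach the new converter (names simplified; the actual helper in ggml.py may differ):

def build_fast_tokenizer(architecture: str, tokenizer_dict: dict):
    # Pick the architecture-specific converter class and run the conversion.
    converter_class = GGUF_TO_FAST_CONVERTERS[architecture]
    converter = converter_class(tokenizer_dict)
    return converter.converted()

# For a Bloom GGUF file, "bloom" selects GGUFBloomConverter, which feeds the tokens and
# merges stored in the GGUF metadata into the GPT2Converter machinery shown above.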


34 changes: 34 additions & 0 deletions src/transformers/modeling_gguf_pytorch_utils.py
@@ -169,6 +169,14 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
elif ".attn_k." in name:
weights = reverse_permute_weights(weights, num_heads, num_kv_heads)

if architecture == "bloom" and "attn_qkv" in name:
num_heads = parsed_parameters["config"]["n_head"]
n_embed = parsed_parameters["config"]["hidden_size"]
if "weight" in name:
weights = reverse_reshape_weights(weights, num_heads, n_embed)
else:
weights = reverse_reshape_bias(weights, num_heads, n_embed)

for tensor_name in tensor_key_mapping:
if tensor_name in name:
name = name.replace(tensor_name, tensor_key_mapping[tensor_name])
@@ -191,3 +199,29 @@ def reverse_permute_weights(weights: np.ndarray, n_head: int, num_kv_heads: Opti
dim = weights.shape[0] // n_head // 2
w = weights.reshape(n_head, dim, 2, *weights.shape[1:])
return w.swapaxes(2, 1).reshape(weights.shape)


def reverse_reshape_weights(weights: np.ndarray, n_head: int, n_embed: int):
# Original reshape implementation
# https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L972-L985
q, k, v = np.array_split(weights, 3, axis=0)

q = q.reshape(n_head, n_embed // n_head, n_embed)
k = k.reshape(n_head, n_embed // n_head, n_embed)
v = v.reshape(n_head, n_embed // n_head, n_embed)
qkv_weights = np.stack([q, k, v], axis=1)

return qkv_weights.reshape(n_head * 3 * (n_embed // n_head), n_embed)


def reverse_reshape_bias(weights: np.ndarray, n_head: int, n_embed: int):
# Original reshape implementation
# https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L986-L998
q_bias, k_bias, v_bias = np.array_split(weights, 3)

q_bias = q_bias.reshape(n_head, n_embed // n_head)
k_bias = k_bias.reshape(n_head, n_embed // n_head)
v_bias = v_bias.reshape(n_head, n_embed // n_head)

qkv_bias = np.stack([q_bias, k_bias, v_bias], axis=1).flatten()
return qkv_bias
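
Since reverse_reshape_weights is meant to undo the reshape performed by llama.cpp's convert_hf_to_gguf.py (linked in the comments above), a quick round-trip check is possible. A sketch, assuming the function is importable from transformers.modeling_gguf_pytorch_utils as defined above; the forward reshape mirrors the linked llama.cpp code:

import numpy as np

from transformers.modeling_gguf_pytorch_utils import reverse_reshape_weights

n_head, n_embed = 4, 16
head_dim = n_embed // n_head

# Bloom's fused QKV weight interleaves [q, k, v] blocks per head: shape (3 * n_embed, n_embed).
original = np.random.rand(3 * n_embed, n_embed).astype(np.float32)

# Forward reshape as in convert_hf_to_gguf.py: regroup into contiguous [Q; K; V] blocks.
qkv = original.reshape(n_head, 3, head_dim, n_embed)
gguf_layout = np.concatenate([qkv[:, i].reshape(n_embed, n_embed) for i in range(3)], axis=0)

# reverse_reshape_weights should recover the original interleaved layout exactly.
assert np.array_equal(reverse_reshape_weights(gguf_layout, n_head, n_embed), original)
# reverse_reshape_bias undoes the analogous flattening for the fused QKV bias.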
4 changes: 2 additions & 2 deletions src/transformers/models/bloom/tokenization_bloom_fast.py
@@ -99,8 +99,8 @@ def __init__(
**kwargs,
):
super().__init__(
vocab_file,
merges_file,
vocab_file=vocab_file,
merges_file=merges_file,
tokenizer_file=tokenizer_file,
unk_token=unk_token,
bos_token=bos_token,
60 changes: 60 additions & 0 deletions tests/quantization/ggml/test_ggml.py
@@ -42,6 +42,8 @@ class GgufIntegrationTests(unittest.TestCase):
llama3_model_id = "NousResearch/Meta-Llama-3-8B-GGUF"
tinyllama_model_id = "PenutChen/TinyLlama-1.1B-Chat-v1.0-GGUF"
phi3_model_id = "microsoft/Phi-3-mini-4k-instruct-gguf"
bloom_model_id = "afrideva/bloom-560m-GGUF"
original_bloom_model_id = "bigscience/bloom-560m"

# standard quants
q4_0_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q4_0.gguf"
@@ -69,6 +71,8 @@ class GgufIntegrationTests(unittest.TestCase):
q4_0_qwen2_model_id = "qwen1_5-0_5b-chat-q4_0.gguf"
q4_0_qwen2_moe_model_id = "Qwen1.5-MoE-A2.7B-Chat.Q4_0.gguf"
q4_llama3_model_id = "Meta-Llama-3-8B-Q4_K_M.gguf"
fp16_bloom_model_id = "bloom-560m.fp16.gguf"
q8_bloom_model_id = "bloom-560m.q8_0.gguf"
f16_tinyllama_model_id = "TinyLlama-1.1B-Chat-v1.0.FP16.gguf"

example_text = "Hello"
@@ -385,6 +389,62 @@ def test_llama3_q4_0(self):
EXPECTED_TEXT = "Hello, I am interested in [The Park]\nThe"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)

def test_bloom_fp16(self):
tokenizer = AutoTokenizer.from_pretrained(self.bloom_model_id, gguf_file=self.fp16_bloom_model_id)
model = AutoModelForCausalLM.from_pretrained(
self.bloom_model_id,
gguf_file=self.fp16_bloom_model_id,
device_map="auto",
torch_dtype=torch.float16,
)

text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
out = model.generate(**text, max_new_tokens=10)

EXPECTED_TEXT = "Hello, I just want to say that I am very"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)

def test_bloom_q8_0(self):
tokenizer = AutoTokenizer.from_pretrained(self.bloom_model_id, gguf_file=self.q8_bloom_model_id)
model = AutoModelForCausalLM.from_pretrained(
self.bloom_model_id,
gguf_file=self.q8_bloom_model_id,
device_map="auto",
torch_dtype=torch.float16,
)

text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
out = model.generate(**text, max_new_tokens=10)

EXPECTED_TEXT = "Hello, I just want to say that I am very"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)

def test_bloom_weights_conversion_fp16(self):
quantized_model = AutoModelForCausalLM.from_pretrained(
self.bloom_model_id,
gguf_file=self.fp16_bloom_model_id,
device_map="auto",
torch_dtype=torch.float16,
)
original_model = AutoModelForCausalLM.from_pretrained(
self.original_bloom_model_id,
device_map="auto",
torch_dtype=torch.float16,
)

quantized_state_dict = quantized_model.state_dict()
original_state_dict = original_model.state_dict()

for (quantized_name, quantized_param), (original_name, original_param) in zip(
quantized_state_dict.items(), original_state_dict.items()
):
if (
"self_attention.query_key_value" in quantized_name
and "self_attention.query_key_value" in original_name
):
self.assertTrue(quantized_param.shape == original_param.shape)
torch.testing.assert_close(quantized_param, original_param)
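
(Not part of the diff: like the other GGUF integration tests in this class, the new Bloom tests download from the Hub and are typically gated as slow tests, so locally they would be run along the lines of RUN_SLOW=1 python -m pytest tests/quantization/ggml/test_ggml.py -k bloom.)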

def test_tokenization_xnli(self):
import tqdm
from datasets import load_dataset
