From 5430803cb511d873340141f56b8062240e4f638d Mon Sep 17 00:00:00 2001
From: Penut Chen
Date: Wed, 3 Jul 2024 16:06:35 +0800
Subject: [PATCH 1/4] support gguf fp16

---
 src/transformers/integrations/ggml.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/transformers/integrations/ggml.py b/src/transformers/integrations/ggml.py
index 5c2d72c345ecf9..a45c43da2b0991 100644
--- a/src/transformers/integrations/ggml.py
+++ b/src/transformers/integrations/ggml.py
@@ -36,6 +36,7 @@
 # Listed here: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md
 GGML_TYPES = {
     "F32": 0,
+    "F16": 1,
     "Q4_0": 2,
     "Q8_0": 8,
     "Q2_K": 10,
@@ -489,6 +490,8 @@ def dequantize_q5_k(data):
 def load_dequant_gguf_tensor(shape, ggml_type, data):
     if ggml_type == GGML_TYPES["F32"]:
         values = data
+    elif ggml_type == GGML_TYPES["F16"]:
+        values = data
     elif ggml_type == GGML_TYPES["Q8_0"]:
         values = dequantize_q8_0(data)
     elif ggml_type == GGML_TYPES["Q4_0"]:

From 02261d3c7826da25949780d3ee7180202e2587fe Mon Sep 17 00:00:00 2001
From: Penut Chen
Date: Wed, 3 Jul 2024 16:10:10 +0800
Subject: [PATCH 2/4] support gguf bf16 with pytorch

---
 src/transformers/integrations/ggml.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/src/transformers/integrations/ggml.py b/src/transformers/integrations/ggml.py
index a45c43da2b0991..9282070c989639 100644
--- a/src/transformers/integrations/ggml.py
+++ b/src/transformers/integrations/ggml.py
@@ -44,6 +44,7 @@
     "Q4_K": 12,
     "Q5_K": 13,
     "Q6_K": 14,
+    "BF16": 30,
 }
 
 # The Blocksizes are reported in bytes
@@ -492,6 +493,11 @@ def load_dequant_gguf_tensor(shape, ggml_type, data):
         values = data
     elif ggml_type == GGML_TYPES["F16"]:
         values = data
+    elif ggml_type == GGML_TYPES["BF16"]:
+        import torch
+        data_uint8 = data.view(np.uint8)
+        tensor_uint8 = torch.from_numpy(data_uint8)
+        values = tensor_uint8.view(torch.bfloat16).float().numpy()
     elif ggml_type == GGML_TYPES["Q8_0"]:
         values = dequantize_q8_0(data)
     elif ggml_type == GGML_TYPES["Q4_0"]:
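
A note on the BF16 branch added in patch 2: numpy has no native bfloat16
dtype, so the raw tensor bytes are viewed as uint8, wrapped in a torch
tensor, reinterpreted as bfloat16, and upcast to float32. Below is a
minimal standalone sketch of that conversion; the sample float values and
the raw_bf16 name are illustrative only, not part of the patch.

import numpy as np
import torch

# BF16 is the upper 16 bits of an IEEE-754 float32, so sample BF16 bytes
# can be fabricated by truncating float32 values (illustrative input only).
f32 = np.array([1.0, -2.5, 3.14159], dtype=np.float32)
raw_bf16 = (f32.view(np.uint32) >> 16).astype(np.uint16).view(np.uint8)

# The conversion from patch 2: view the bytes as uint8, hand them to torch,
# reinterpret as bfloat16, then upcast to float32 so numpy can hold them.
data_uint8 = raw_bf16.view(np.uint8)
tensor_uint8 = torch.from_numpy(data_uint8)
values = tensor_uint8.view(torch.bfloat16).float().numpy()

print(values)  # ~[1.0, -2.5, 3.140625]; BF16 keeps only 8 mantissa bits
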
Use a library" self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT) + def test_f16(self): + tokenizer = AutoTokenizer.from_pretrained(self.tinyllama_model_id, gguf_file=self.f16_tinyllama_model_id) + model = AutoModelForCausalLM.from_pretrained( + self.tinyllama_model_id, gguf_file=self.f16_tinyllama_model_id + ).to(torch_device) + + text = tokenizer(self.example_text, return_tensors="pt").to(torch_device) + out = model.generate(**text, max_new_tokens=10) + + EXPECTED_TEXT = "Hello, World!\n\n5. Node.js" + self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT) + def test_mistral_q4_0(self): tokenizer = AutoTokenizer.from_pretrained(self.mistral_model_id, gguf_file=self.q4_0_mistral_model_id) model = AutoModelForCausalLM.from_pretrained( From a8310052da3a78aee7fdade1a155383c2413d3e1 Mon Sep 17 00:00:00 2001 From: Penut Chen Date: Wed, 24 Jul 2024 23:42:51 +0800 Subject: [PATCH 4/4] remove bf16 --- src/transformers/integrations/ggml.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/transformers/integrations/ggml.py b/src/transformers/integrations/ggml.py index 34ad64b3be67e3..47f3f0cf8d57b4 100644 --- a/src/transformers/integrations/ggml.py +++ b/src/transformers/integrations/ggml.py @@ -44,7 +44,6 @@ "Q4_K": 12, "Q5_K": 13, "Q6_K": 14, - "BF16": 30, } # The Blocksizes are reported in bytes @@ -493,11 +492,6 @@ def load_dequant_gguf_tensor(shape, ggml_type, data): values = data elif ggml_type == GGML_TYPES["F16"]: values = data - elif ggml_type == GGML_TYPES["BF16"]: - import torch - data_uint8 = data.view(np.uint8) - tensor_uint8 = torch.from_numpy(data_uint8) - values = tensor_uint8.view(torch.bfloat16).float().numpy() elif ggml_type == GGML_TYPES["Q8_0"]: values = dequantize_q8_0(data) elif ggml_type == GGML_TYPES["Q4_0"]: