
Add support for GGUF Phi-3 #31844

Merged: 8 commits into huggingface:main on Sep 10, 2024

Conversation

@a8nova (Contributor) commented Jul 8, 2024

What does this PR do?

Add support for GGUF Phi-3.
Use #31175 as a guide for adding GGUF support for Phi-3.

Fixes #31826
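Once merged, a Phi-3 GGUF checkpoint can be loaded directly through from_pretrained with the gguf_file argument, like the other supported GGUF architectures. A minimal usage sketch (the filename is illustrative; use whichever quantized .gguf file the repo actually ships):

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "microsoft/Phi-3-mini-4k-instruct-gguf"
gguf_file = "Phi-3-mini-4k-instruct-q4.gguf"  # illustrative filename

# The GGUF weights are dequantized and loaded into the regular Phi3 architecture.
tokenizer = AutoTokenizer.from_pretrained(model_id, gguf_file=gguf_file)
model = AutoModelForCausalLM.from_pretrained(model_id, gguf_file=gguf_file)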

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@a8nova mentioned this pull request on Jul 8, 2024
@a8nova (Contributor Author) commented Jul 8, 2024

@khalidmohammedet

@amyeroberts (Collaborator) commented

cc @SunMarc

@a8nova (Contributor Author) commented Jul 15, 2024

Hi @SunMarc @amyeroberts - while testing my changes I ran into a failure. I checked Qwen2's test and it has the same failure. I am attaching the error below. I am running this on a Colab T4 GPU; any idea why Qwen2's GGUF loading is also failing for me? Command to reproduce the failure: python -m pytest -v ./tests/quantization/ggml/test_ggml.py::GgufIntegrationTests::test_qwen2_q4_0. Thanks!

============================================= FAILURES =============================================
_______________________________ GgufIntegrationTests.test_qwen2_q4_0 _______________________________

self = <ggml.test_ggml.GgufIntegrationTests testMethod=test_qwen2_q4_0>

    def test_qwen2_q4_0(self):
        tokenizer = AutoTokenizer.from_pretrained(self.qwen2_model_id, gguf_file=self.q4_0_qwen2_model_id)
>       model = AutoModelForCausalLM.from_pretrained(
            self.qwen2_model_id, gguf_file=self.q4_0_qwen2_model_id, device_map="auto", torch_dtype=torch.float16
        )

tests/quantization/ggml/test_ggml.py:166: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
src/transformers/models/auto/auto_factory.py:564: in from_pretrained
    return model_class.from_pretrained(
src/transformers/modeling_utils.py:3583: in from_pretrained
    state_dict = load_gguf_checkpoint(gguf_path, return_tensors=True)["tensors"]
src/transformers/modeling_gguf_pytorch_utils.py:146: in load_gguf_checkpoint
    weights = load_dequant_gguf_tensor(shape=shape, ggml_type=tensor.tensor_type, data=tensor.data)
src/transformers/integrations/ggml.py:526: in load_dequant_gguf_tensor
    values = dequantize_q6_k(data)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

data = memmap([[192,   8,  15, ..., 116, 211, 128],
        [144, 213, 102, ...,  56,  14, 129],
        [ 35,  71, 212, ...,...225,  22,   1],
        [112,  36, 159, ..., 225,  22,   1],
        [112,  36, 143, ..., 225,  22,   1]], dtype=uint8)

    def dequantize_q6_k(data):
        # C implementation
        # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L2275
        # C struct definition
        # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L152
        block_size = GGML_BLOCK_SIZES["Q6_K"]
        num_blocks = len(data) // block_size
    
>       data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2)
E       ValueError: cannot reshape array of size 63813120 into shape (723,105)

src/transformers/integrations/ggml.py:311: ValueError
--------------------------------------- Captured stderr call ---------------------------------------
Converting and de-quantizing GGUF tensors...:   0%|          | 0/291 [00:00<?, ?it/s]
========================================= warnings summary =========================================
../../usr/local/lib/python3.10/dist-packages/_pytest/config/__init__.py:1373
  /usr/local/lib/python3.10/dist-packages/_pytest/config/__init__.py:1373: PytestConfigWarning: Unknown config option: doctest_glob
  
    self._warn_or_fail_if_strict(f"Unknown config option: {key}\n")

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
===================================== short test summary info ======================================

@SunMarc (Member) commented Jul 16, 2024

Hi @a8nova ! The latest version of ggml that was released a few days ago is not compatible with the integration. We still need to fix this ! #31725 (comment). I will try to find some time to fix this but feel free to have a look ! In the meantime, you can also install the previous version of ggml and continue the PR from there !

@a8nova (Contributor Author) commented Jul 17, 2024

Got it, I will try with the previous ggml version. Thank you @SunMarc!

@a8nova (Contributor Author) commented Jul 22, 2024

Hi @SunMarc - I am getting gibberish output for my test. I think it has to do with Phi-3 having a slightly different attention class, where q, k, and v are a single variable.

I am seeing the output below while loading the GGUF weights, and since the model output differs across runs, it looks like some weights are not being correctly initialized. Right now I am mapping attn_q, attn_v, and attn_k to self_attn.qkv_proj, but that feels incorrect; either the mapping is wrong or the underlying code needs to support this fused attention layout.

Converting and de-quantizing GGUF tensors...: 100%|██████████| 195/195 [00:31<00:00,  6.25it/s]
Some weights of Phi3ForCausalLM were not initialized from the model checkpoint at microsoft/Phi-3-mini-4k-instruct-gguf and are newly initialized: ['model.layers.0.self_attn.qkv_proj.weight', 'model.layers.1.self_attn.qkv_proj.weight', 'model.layers.10.self_attn.qkv_proj.weight', 'model.layers.11.self_attn.qkv_proj.weight', 'model.layers.12.self_attn.qkv_proj.weight', 'model.layers.13.self_attn.qkv_proj.weight', 'model.layers.14.self_attn.qkv_proj.weight', 'model.layers.15.self_attn.qkv_proj.weight', 'model.layers.16.self_attn.qkv_proj.weight', 'model.layers.17.self_attn.qkv_proj.weight', 'model.layers.18.self_attn.qkv_proj.weight', 'model.layers.19.self_attn.qkv_proj.weight', 'model.layers.2.self_attn.qkv_proj.weight', 'model.layers.20.self_attn.qkv_proj.weight', 'model.layers.21.self_attn.qkv_proj.weight', 'model.layers.22.self_attn.qkv_proj.weight', 'model.layers.23.self_attn.qkv_proj.weight', 'model.layers.24.self_attn.qkv_proj.weight', 'model.layers.25.self_attn.qkv_proj.weight', 'model.layers.26.self_attn.qkv_proj.weight', 'model.layers.27.self_attn.qkv_proj.weight', 'model.layers.28.self_attn.qkv_proj.weight', 'model.layers.29.self_attn.qkv_proj.weight', 'model.layers.3.self_attn.qkv_proj.weight', 'model.layers.30.self_attn.qkv_proj.weight', 'model.layers.31.self_attn.qkv_proj.weight', 'model.layers.4.self_attn.qkv_proj.weight', 'model.layers.5.self_attn.qkv_proj.weight', 'model.layers.6.self_attn.qkv_proj.weight', 'model.layers.7.self_attn.qkv_proj.weight', 'model.layers.8.self_attn.qkv_proj.weight', 'model.layers.9.self_attn.qkv_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
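For reference, a minimal sketch of the fused-projection idea (hypothetical helper, not the code in this PR): if the GGUF file stores separate q/k/v tensors, Phi-3's single qkv_proj weight has to be rebuilt by concatenating them in (q, k, v) order along the output dimension, rather than by renaming one of them.

import torch

# Hypothetical sketch: rebuild a fused Phi-3 qkv_proj weight from separate q/k/v tensors.
# q_w, k_w, v_w stand for the dequantized attn_q / attn_k / attn_v weights of one layer.
def fuse_qkv(q_w: torch.Tensor, k_w: torch.Tensor, v_w: torch.Tensor) -> torch.Tensor:
    # Phi3Attention slices qkv_proj's output as [query | key | value],
    # so the fused weight is the row-wise concatenation in exactly that order.
    return torch.cat([q_w, k_w, v_w], dim=0)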

For the tokenizer, I am using Llama's tokenizer, since Phi-3 uses the Llama tokenizer (that is what I read in the research paper), and I have added the Phi-3 special tokens. When I dump the tokenizer from the GGUF file, it looks similar to the tokenizer loaded by AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct"), so I don't think there is any issue there.
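A quick way one could double-check that (the GGUF filename is illustrative) is to compare the token ids produced by both tokenizers:

from transformers import AutoTokenizer

hf_tok = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
gguf_tok = AutoTokenizer.from_pretrained(
    "microsoft/Phi-3-mini-4k-instruct-gguf", gguf_file="Phi-3-mini-4k-instruct-q4.gguf"
)

text = "Bananas and dragonfruits make a surprisingly good smoothie."
# Identical ids would rule the tokenizer out as the source of the gibberish output.
print(hf_tok(text)["input_ids"] == gguf_tok(text)["input_ids"])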

@kibru9399 commented Aug 19, 2024

Hi! I’m interested in working on this issue. Is anyone currently working on it? If not, I’d love to take it on as my first one.

@a8nova (Contributor Author) commented Aug 19, 2024

Hi @kibru9399 - I am actively working on this; it is almost finalized.

@kibru9399 commented Aug 19, 2024 via email

@a8nova (Contributor Author) commented Aug 25, 2024

Hello @SunMarc @amyeroberts - This PR is ready for review!

A few remaining TODOs on my end:

  1. The generated output from the GGUF model microsoft/Phi-3-mini-4k-instruct-gguf differs from the non-GGUF model microsoft/Phi-3-mini-4k-instruct. This is also true for Qwen2. It can be due to many different reasons, but I want to make sure it is not due to a bug in the conversion code in this PR.
  2. Test this PR with Phi-3.5 weights.
  3. There is a warning "Merges were not in checkpoint, building merges on the fly." for microsoft/Phi-3-mini-4k-instruct-gguf. Do you know if I need to look into this warning?

Thanks!

@SunMarc (Member) left a review comment:

Nice, thanks for adding this @a8nova! I left a comment.

The generated output from the GGUF model microsoft/Phi-3-mini-4k-instruct-gguf differs from the non-GGUF model microsoft/Phi-3-mini-4k-instruct. This is also true for Qwen2. It can be due to many different reasons, but I want to make sure it is not due to a bug in the conversion code in this PR.

Since the model is quantized, it is normal that we don't get exactly the same output. I would advise checking that we get approximately the same perplexity in transformers as with llama.cpp!
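A rough sketch of such a check on the transformers side (the GGUF filename and evaluation text are illustrative); the llama.cpp number would come from its perplexity tool on the same text:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "microsoft/Phi-3-mini-4k-instruct-gguf"
gguf_file = "Phi-3-mini-4k-instruct-q4.gguf"  # illustrative filename

tokenizer = AutoTokenizer.from_pretrained(model_id, gguf_file=gguf_file)
model = AutoModelForCausalLM.from_pretrained(model_id, gguf_file=gguf_file)

text = "GGUF is a binary file format used by llama.cpp to store quantized models. " * 40
enc = tokenizer(text, return_tensors="pt")

with torch.no_grad():
    loss = model(**enc, labels=enc["input_ids"]).loss

# Exponentiated mean cross-entropy = perplexity; compare the order of magnitude with llama.cpp.
print("perplexity:", torch.exp(loss).item())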

There is a warning "Merges were not in checkpoint, building merges on the fly." for microsoft/Phi-3-mini-4k-instruct-gguf. Do you know if I need to look into this warning?

Since Phi-3 uses a BPE tokenizer, it needs to have a merges attribute. I'm not sure why we are not able to get it from the GGUF file; you can see that the model hosted on the Hub does have it.

Comment on the following lines:

class GGUFPhi3Converter(LlamaConverter):
    def __init__(self, tokenizer_dict):
        self.original_tokenizer = GGUFTokenizerSkeleton(tokenizer_dict)
        self.additional_kwargs = {}

    def converted(self) -> Tokenizer:
        vocab = {word: i for i, word in enumerate(self.original_tokenizer.tokens)}
        merges = self.original_tokenizer.merges
        tokenizer = Tokenizer(BPE(vocab, merges, unk_token="<unk>", fuse_unk=True, byte_fallback=True))
        tokenizer.decoder = decoders.Sequence(
            [
                decoders.ByteFallback(),
                decoders.Fuse(),
                decoders.Replace("▁", " "),
            ]
        )
        # add the special tokens from phi3 tokenizer config
        tokenizer.add_special_tokens(
            [
                AddedToken("</s>", rstrip=True, lstrip=False, normalized=False, special=True),
                AddedToken("<|endoftext|>", normalized=False, special=True),
                AddedToken("<|assistant|>", rstrip=True, normalized=False, special=True),
                AddedToken("<|placeholder1|>", rstrip=True, normalized=False, special=True),
                AddedToken("<|placeholder2|>", rstrip=True, normalized=False, special=True),
                AddedToken("<|placeholder3|>", rstrip=True, normalized=False, special=True),
                AddedToken("<|placeholder4|>", rstrip=True, normalized=False, special=True),
                AddedToken("<|system|>", rstrip=True, normalized=False, special=True),
                AddedToken("<|end|>", rstrip=True, normalized=False, special=True),
@SunMarc (Member) commented on this code:

Can you follow the GGUFLlamaConverter structure? (e.g. adding vocab, tokenizer, merges methods, ...)

@a8nova (Contributor Author) replied:

Done!

@SunMarc (Member) commented Aug 26, 2024

LMK when you have finished your TODOs, so that I can ask for a final review from a core maintainer!

@a8nova (Contributor Author) commented Aug 28, 2024

Hi @SunMarc - For the Phi-3.5 weights: there seems to be something wrong with the Phi-3.5 template, because the model gives weird output. For example, with the input "Can you provide ways to eat combinations of bananas and dragonfruits?" the model outputs:

Output:
 <s>Canyouprovidewaystoeatcombinationsofbananasanddragonfruits?

Assistant:
Certainly! Here's a simple recipe for a banana and dragon fruit smoothie:

Ingredients:
- 1 ripe banana, peeled and sliced
- 1/2 cup of diced dragon fruit (also known as pitaya)
- 1

And with the input "Hello", the model gives (HTML) output:

Output:
 <s>Hello, World!</h1>

<p>This is a paragraph.</p>

<ul>
  <li>Item 1</li>
  <li>Item 2</li>dict

<script>
  console.log("Hello, World!");
</script>

<img src="image.jpg

My understanding was that Phi-3.5 is architecturally identical to Phi-3, but there seem to be some differences, and I don't see them documented anywhere. I won't worry about this issue in the current scope.

For the missing-merges issue: I don't see merges in the GGUF file even though they exist in the original checkpoint. I see that the merges are built on the fly if they are missing from the GGUF checkpoint, which slows things down. Are the missing merges a bug in the safetensors -> GGUF conversion? How do you suggest we move forward? (I also converted a few Phi-3 variants via gguf-my-repo, and they are still missing the merges in the converted GGUF files.)

If the missing merges are not a blocker, then I think this PR is ready for a final review; if they are a blocker, we can get to the bottom of that first. Thanks!

@SunMarc (Member) commented Sep 2, 2024

For the missing-merges issue: I don't see merges in the GGUF file even though they exist in the original checkpoint. I see that the merges are built on the fly if they are missing from the GGUF checkpoint, which slows things down. Are the missing merges a bug in the safetensors -> GGUF conversion? How do you suggest we move forward? (I also converted a few Phi-3 variants via gguf-my-repo, and they are still missing the merges in the converted GGUF files.)

If the missing merges are not a blocker, then I think this PR is ready for a final review; if they are a blocker, we can get to the bottom of that first. Thanks!

Could you quickly check whether we have the issue on a Llama GGUF model? If not, I think we should fix this with a PR in the llama.cpp repo. We would have to fix this conversion script: https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py.

@a8nova (Contributor Author) commented Sep 2, 2024

Hi @SunMarc! When I try an already converted Llama GGUF like TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF, which is also used inside test_ggml.py, I don't see the missing-merges issue. But a Llama GGUF I just converted using gguf-my-repo, a8nova/TinyLlama-1.1B-Chat-v1.0-Q4_0-GGUF, is missing merges. So it looks like a recent change (in the llama.cpp repo?) broke the conversion for all models...
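One quick way to confirm which files carry merges is to inspect the GGUF metadata directly; a small sketch using the gguf Python package (the local path is illustrative):

from gguf import GGUFReader  # the reader shipped with llama.cpp's gguf package

reader = GGUFReader("tinyllama-1.1b-chat-v1.0.Q4_0.gguf")  # illustrative local path
# A BPE checkpoint that carries merges should expose the "tokenizer.ggml.merges" field.
print([name for name in reader.fields if name.startswith("tokenizer.")])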

@SunMarc (Member) commented Sep 3, 2024

Hi @SunMarc! When I try an already converted Llama GGUF like TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF, which is also used inside test_ggml.py, I don't see the missing-merges issue. But a Llama GGUF I just converted using gguf-my-repo, a8nova/TinyLlama-1.1B-Chat-v1.0-Q4_0-GGUF, is missing merges. So it looks like a recent change (in the llama.cpp repo?) broke the conversion for all models...

Thanks for the investigation! We should definitely try to fix this, as it can be quite troublesome if we have to recreate the merges for all recent Llama models.

@a8nova (Contributor Author) commented Sep 3, 2024

I was able to check out llama.cpp and reproduce the missing-merges bug locally, and I also have a fix.

The bug shows up for the Llama-family models only. I have narrowed it down to the _set_vocab_llama_hf() routine inside convert_hf_to_gguf.py#L806. I can fix it by passing load_merges=True on that line, like:

special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True, n_vocab=len(tokens))

It looks like we only go inside _set_vocab_llama_hf() if self._set_vocab_sentencepiece(), which is wrapped in a try/except inside the LlamaModel class, fails. It does fail in my case, since there is no tokenizer.model file for the Llama model or for Phi-3, only a tokenizer.json.

If this fix makes sense to you, I will create a PR in the llama.cpp repo!

Update: I think the bug can also happen for any model class that calls _set_vocab_sentencepiece(). There is also the case where a tokenizer.model file is present, in which case _create_vocab_sentencepiece() never throws an exception; when we are back in _set_vocab_sentencepiece(), load_merges is not passed as True there either, so that would be another place we would have to fix.

@SunMarc (Member) commented Sep 4, 2024

I would first open an issue on llama.cpp to see if anyone there has more insight into this, and explain the potential fix! Thanks for investigating, I really appreciate it.

@a8nova (Contributor Author) commented Sep 4, 2024

Done! ggerganov/llama.cpp#9309

  • I noticed that ggml.py in this repo builds the merges if they're missing. Is the primary drawback of missing merges just the performance slowdown, or could it lead to other issues as well?

  • Any ideas why phi3.5 is not working?

@SunMarc (Member) commented Sep 4, 2024

I noticed that ggml.py in this repo builds the merges if they're missing. Is the primary drawback of missing merges just the performance slowdown, or could it lead to other issues as well?

Well, it will just lead to increased loading time.

Any ideas why phi3.5 is not working?

Not really. Can you check that Phi-3.5 works correctly in transformers (without GGUF)? If it does, you can try converting it to GGUF and then loading it back in transformers.
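A hedged sketch of that first step (the model id and prompt are assumptions; adjust to the actual Phi-3.5 checkpoint being tested):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "microsoft/Phi-3.5-mini-instruct"  # assumed HF checkpoint id for Phi-3.5
tok = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")

messages = [{"role": "user", "content": "Can you provide ways to eat combinations of bananas and dragonfruits?"}]
inputs = tok.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)

out = model.generate(inputs, max_new_tokens=64)
print(tok.decode(out[0], skip_special_tokens=True))

# If this output looks sane, convert the checkpoint with llama.cpp's convert_hf_to_gguf.py
# and load the resulting .gguf back through from_pretrained(..., gguf_file=...).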

@a8nova (Contributor Author) commented Sep 6, 2024

Hi @SunMarc! I would like to finalize this PR if possible. Are there any blockers, or can we get final reviews and get this merged? Thanks!

Regarding Phi-3.5: the non-GGUF checkpoints have the same behavior; something is off about the template/prompt.

@SunMarc (Member) commented Sep 9, 2024

No blockers @a8nova. I really appreciate your work!

As for the merges issue, this is something we need to tackle separately from this PR.
Can you create a separate issue for Phi-3.5? It seems that something is not working in transformers, and we need to fix that first. Then the GGUF issue will be solved naturally.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@LysandreJik (Member) left a review comment:

Boom, awesome! Thanks for the contribution @a8nova

@a8nova (Contributor Author) commented Sep 10, 2024

Thank you, @SunMarc! Before I file an issue for the Phi-3.5 bug, let me first confirm it; I am suspicious of what I saw, and maybe I did something wrong.
Also, I am curious: what is the purpose of requiring torch to run on the GPU in test_ggml.py? I have been running the tests locally on my CPU.

@require_torch_gpu
@slow
class GgufIntegrationTests(unittest.TestCase):
    original_model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

@SunMarc (Member) commented Sep 10, 2024

Also, I am curious, what is the purpose of requiring torch to run on the GPU in test_ggml.py? I have been running the tests locally on my CPU

I guess it is simply to test the model faster. Maybe you can try on GPU to see if this fixes anything ?

@SunMarc SunMarc merged commit 96429e7 into huggingface:main Sep 10, 2024
23 checks passed
itazap pushed a commit to NielsRogge/transformers that referenced this pull request Sep 20, 2024
* Update docs for GGUF supported models

* Add tensor mappings and define class GGUFPhi3Converter

* Fix tokenizer

* Working version

* Attempt to fix some CI failures

* Run ruff format

* Add vocab, merges, decoder methods like LlamaConverter

* Resolve conflicts since Qwen2Moe was added to gguf

- I missed one place when resolving the conflict
- I also made a mistake with test_ggml.py; it has now been fixed to reflect its master version.
amyeroberts pushed a commit to amyeroberts/transformers that referenced this pull request Oct 2, 2024
BernardZach pushed a commit to BernardZach/transformers that referenced this pull request Dec 5, 2024
BernardZach pushed a commit to innovationcore/transformers that referenced this pull request Dec 6, 2024
Successfully merging this pull request may close these issues.

Add support for GGUF Phi-3
6 participants