
Fix for "leaf Variable that requires grad" Error in In-Place Operation #1372

Merged: 14 commits into huggingface:main on Mar 4, 2024

Conversation

DopeorNope-Lee
Contributor

Issue:
The current implementation in the file peft/tuners/lora/layer.py encounters a runtime error due to an in-place operation on a leaf variable that requires gradient computation. Specifically, the error is triggered by the following line of code:

result += (after_A @ embedding_B) * scaling

This line uses the += operator, which modifies the result tensor in-place. When result is a leaf variable with requires_grad=True, such in-place operations are incompatible with PyTorch's autograd system, leading to a "RuntimeError: a leaf Variable that requires grad is being used in an in-place operation".
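
A minimal standalone sketch in plain PyTorch (not the PEFT code itself) that reproduces the same error class:

import torch

# A leaf tensor that requires grad, standing in for `result` in the failing case.
result = torch.zeros(4, 8, requires_grad=True)
update = torch.randn(4, 8)

try:
    result += update  # in-place op on a leaf that requires grad
except RuntimeError as err:
    print(err)  # a leaf Variable that requires grad is being used in an in-place operation.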

Solution:
To resolve this issue, I propose modifying the in-place operation to a regular operation that creates a new tensor. This change ensures that the original value of result is not altered, thereby maintaining compatibility with the autograd system. The updated line of code is as follows:

result = result + (after_A @ embedding_B) * scaling

This modification allows the computation to proceed without altering the original result tensor, thus avoiding the RuntimeError and ensuring proper gradient calculation during backpropagation.
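
A minimal sketch of the fixed pattern, again in plain PyTorch rather than the actual PEFT layer:

import torch

weight = torch.zeros(4, 8, requires_grad=True)  # leaf tensor
update = torch.randn(4, 8)

result = weight + update  # out-of-place: a new tensor is created, no error
result.sum().backward()
print(weight.grad.shape)  # torch.Size([4, 8]) -- gradients still reach the leaf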

In short, before this modification I encountered
"RuntimeError: a leaf Variable that requires grad is being used in an in-place operation."

After the revision above, the error no longer occurs.

Best

Contributor

@younesbelkada younesbelkada left a comment

Thanks, the fix sounds good!
Can you propagate the fix to the other LoRA layer types (Conv, etc.)? Also, can you share a small reproducible snippet of the bug?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@DopeorNope-Lee
Contributor Author

DopeorNope-Lee commented Jan 24, 2024

@younesbelkada I tried CAUSAL_LM with a Mixtral model.

Also, my target modules include the embedding layer, so this code revision was made in the Embedding class of the LoRA layer.

So my PR only touches the Embedding class.

@DopeorNope-Lee
Contributor Author

@younesbelkada I also changed lora/bnb.py to fix the error below:
RuntimeError: Output 0 of MatMul8bitLtBackward is a view and is being modified inplace. This view was created inside a custom Function (or because an input was returned as-is) and the autograd logic to handle view+inplace would override the custom backward associated with the custom Function, leading to incorrect gradients. This behavior is forbidden. You can fix this by cloning the output of the custom Function.
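
The same class of view+inplace error can be illustrated without bitsandbytes. A hedged sketch using a toy custom autograd Function (PassThrough is hypothetical, not the bitsandbytes code):

import torch

class PassThrough(torch.autograd.Function):
    # Toy Function whose forward returns a view of its input, mirroring
    # the "an input was returned as-is" case from the error message.
    @staticmethod
    def forward(ctx, x):
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output

x = torch.randn(4, requires_grad=True)
out = PassThrough.apply(x)

# "out += 1" here would raise the view+inplace RuntimeError. Cloning the
# output (or using an out-of-place add) avoids it:
out = out.clone()
out += 1
out.sum().backward()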

Contributor

@pacman100 pacman100 left a comment

Thank you @DopeorNope-Lee for the fixes, this makes sense. I agree with Younes that it would be better to make this change across all LoRA layers. However, that could be part of a separate PR in case the issue is only noticed for the layers changed in this PR.

@DopeorNope-Lee
Contributor Author

@pacman100 I will fix all of the layers and mention it again here!

Thanks for your recommendation!

@younesbelkada
Contributor

Thanks @DopeorNope-Lee! Let us know if you need any help.

@BenjaminBossan
Member

BenjaminBossan commented Feb 6, 2024

Thanks @DopeorNope-Lee for providing this PR. Do you have a minimal code example that demonstrates the error with the in-place operation?

Edit: Is this related to #1425?

@DopeorNope-Lee
Contributor Author

DopeorNope-Lee commented Feb 6, 2024

Thanks @BenjaminBossan, I looked at the PR you mentioned, but there are some differences.

First, that code is related to full fine-tuning (some call it continued pre-training), whereas the in-place error here occurs in the fine-tuning case.

So I revised all of the a += b operations to a = a + b, as described above.

@younesbelkada Hi, I revised all of the operators in the LoRA layers. However, I saw 'conflicts that must be resolved'.

I think the previous code used the copy method, but the recent version uses the clone method, so I also updated that.

If there are more improvements or other issues, let me know and I will follow up.

@BenjaminBossan
Member

First, that code is related to full fine-tuning (some call it continued pre-training), whereas the in-place error here occurs in the fine-tuning case.

Thanks for explaining. I wrote a small test to check full fine-tuning and didn't encounter any error when training. Could you please provide a minimal example? This is also important to have as a unit test so that we can prevent regressions in the future.

@DopeorNope-Lee
Contributor Author

DopeorNope-Lee commented Feb 7, 2024

@BenjaminBossan Did you add embed_tokens to the target modules?

The detailed setup is as follows:

model: TomGrc/FusionNet_7Bx2_MoE_14B

    load_in_4bit=True
    trust_remote_code=True
    torch_dtype=torch.bfloat16
    attn_implementation='flash_attention_2'

LoRA:
target_modules = ['embed_tokens', 'q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'down_proj', 'up_proj', 'lm_head']
bias='none'
task_type='CAUSAL_LM'

learning rate scheduler: cosine
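
A hedged sketch of how this setup might look in code; the model id, loading flags, and LoRA targets come from the settings above, while the remaining imports and arguments are assumptions rather than the exact script used:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model

model = AutoModelForCausalLM.from_pretrained(
    "TomGrc/FusionNet_7Bx2_MoE_14B",
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)

lora_config = LoraConfig(
    target_modules=["embed_tokens", "q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "down_proj", "up_proj", "lm_head"],
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)  # fine-tuning this model is the setup in which the error was reported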

@BenjaminBossan
Member

Did you add embed_tokens to the target modules?

I added this test to test_custom_models.py:

    @parameterized.expand(TEST_CASES)
    def test_training_full_finetuning(self, test_name, model_id, config_cls, config_kwargs):
        # check if training the full model works, no error is raised; only check custom models since they are small
        # so full finetuning shouldn't take too long
        model = self.transformers_class.from_pretrained(model_id)
        config = config_cls(
            base_model_name_or_path=model_id,
            **config_kwargs,
        )
        model = get_peft_model(model, config)
        model = model.to(self.torch_device)
        model.train()
        model.requires_grad_(True)  # make all parameters trainable

        optim = torch.optim.SGD(model.parameters(), lr=0.1)
        inputs = self.prepare_inputs_for_testing()

        for _ in range(5):
            optim.zero_grad()
            output = model(**inputs)[0]
            loss = output.sum()
            loss.backward()
            optim.step()

These tests include examples with embedding layers in the target_modules.

@DopeorNope-Lee
Contributor Author

@BenjaminBossan How about trying the Mixtral model with a LoRA layer? I think this error does not occur with full fine-tuning, only with LoRA fine-tuning.

@BenjaminBossan
Member

How about trying the Mixtral model with a LoRA layer?

I see, thanks. I don't have a machine available to test full mixtral, so I couldn't test it.

@DopeorNope-Lee
Contributor Author

@BenjaminBossan Then may I help you with testing LoRA fine-tuning on Mixtral?

@BenjaminBossan
Member

Then may I help you with testing LoRA fine-tuning on Mixtral?

Thanks, that's not necessary. My main concern is to have some kind of test to ensure that we don't have regressions in the future, but maybe that's not easily possible here? I tried a tiny mixtral model I found on HF but that didn't trigger any error. Running full mixtral won't work on our CI.

@DopeorNope-Lee
Contributor Author

@BenjaminBossan Have you tried fine-tuning with LoRA?

I attempted fine-tuning using LoRA: instruction tuning with LoRA adapters, not full fine-tuning.

But your previous code looks like full fine-tuning, I think.

@BenjaminBossan
Member

Well, we have a bunch of tests for fine-tuning with LoRA. I wrote the test for full fine-tuning because you had said earlier:

that code is related to full fine-tuning

Maybe I misunderstood what you meant.

@DopeorNope-Lee
Contributor Author

@BenjaminBossan
It seems we both misunderstood each other, each thinking the other discussion (#1425) was about full fine-tuning.

However, the issue I'm referring to here is instruction tuning (fine-tuning) with LoRA adapters loaded through load_in_4bit.

Following your suggestion, I've also started a training test in a fresh environment.

It would be great if we could run this together and share the relevant information.

I'm very grateful for your contribution to this open-source project.

@BenjaminBossan
Member

It seems we both misunderstood each other

These things happen, glad we're now on the same page.

Following your suggestion, I've also started a training test in a fresh environment.

It would be great if we could run this together and share the relevant information.

Thanks. If you have something to share, let me know. At the end of the day, if we have confirmation that this PR fixes a real world issue, even if we cannot add it to a unit test, it's fine with me. Maybe we can upload the script somewhere and add a link to it as a comment.

@DopeorNope-Lee
Contributor Author

DopeorNope-Lee commented Feb 9, 2024

@BenjaminBossan Sure!

I'm also sharing a recent test result from Mixtral LoRA fine-tuning.

(screenshot: test run with the released peft version)

After that, I removed the latest released version of peft and installed the modified library from this PR:

pip install -q -U git+https://github.com/DopeorNope-Lee/peft-modified.git

(screenshot: test run with the patched version)

Now it runs really well!

@BenjaminBossan
Member

@DopeorNope-Lee Could you please run make style on your changes? This might require updating ruff for it to work correctly. I think the issue is the excess empty line at line 493.

@DopeorNope-Lee
Contributor Author

@BenjaminBossan I ran make style and pushed it!

Member

@BenjaminBossan BenjaminBossan left a comment

Thanks for the fixes and the fruitful discussion. LGTM.

I'm also sharing a recent test result from Mixtral LoRA fine-tuning.

Is it possible for you to share the script?

@DopeorNope-Lee
Contributor Author

DopeorNope-Lee commented Feb 13, 2024

@BenjaminBossan Sure, I usually use the Platypus code with a bash file:

# BASE and Data

BASIS=TomGrc/FusionNet_7Bx2_MoE_14B
DATA=DopeorNope/adaption_new_v1.2

# fine_tuning output
output_directory=ko-mixtral_v1.4

# merged output
final_dir=Ko-Mixtral-v1.4-MoE-7Bx2

# repository directory
repo_dir=DopeorNope/Ko-Mixtral-v1.4-MoE-7Bx2


python finetune.py \
    --base_model $BASIS \
    --data-path $DATA \
    --output_dir $output_directory \
    --lora_target_modules '[embed_tokens, q_proj, k_proj, v_proj, o_proj, gate_proj, down_proj,up_proj, lm_head]' \
    --batch_size 256 \
    --micro_batch_size 2 \
    --num_epochs 1 \
    --learning_rate 1e-5 \
    --lr_scheduler 'cosine' \
    --cutoff_len 8192 \
    --lora_r 64 \
    --resume_from_checkpoint False \
    --lora_alpha 16 \
    --train_on_inputs False \
    --add_eos_token False \
    --group_by_length False \
    --prompt_template_name alpaca \
    --warmup_steps 10 \
    --lora_dropout 0.05

@BenjaminBossan
Member

I think this PR should be ready to be merged, right? @DopeorNope-Lee could you please fix the small merge conflict?

@younesbelkada @pacman100 do you have further comments?

@DopeorNope-Lee
Contributor Author

@BenjaminBossan Sure, I fixed it!

@DopeorNope-Lee
Contributor Author

@pacman100 @younesbelkada @BenjaminBossan Hi,
I think we are ready to merge this. I tried to merge it, but two pending reviews were left, so I could not merge it myself.

Could you approve my PR?

Contributor

@pacman100 pacman100 left a comment

Thank you @DopeorNope-Lee for the fixes, LGTM!

@BenjaminBossan
Member

@DopeorNope-Lee I think the last merge with main resulted in incorrect code, with weights being merged twice, which is causing the failing CI. Could you please take a look?

@DopeorNope-Lee
Contributor Author

@BenjaminBossan
Hmm, it seems there was an error while resolving a conflict with the existing code. So I've brought over the current version of layer.py and applied the same modifications as in my previous PR changes. Can you check this for me?

@BenjaminBossan
Member

Thanks @DopeorNope-Lee for correcting the resolution to work as intended. This LGTM now.

@BenjaminBossan BenjaminBossan merged commit 34f3fba into huggingface:main Mar 4, 2024
14 checks passed
BenjaminBossan pushed a commit to BenjaminBossan/peft that referenced this pull request Mar 14, 2024