Hi, I'm trying to fine-tune BLIP-2 (with OPT 2.7B as the language model) using Flash Attention 2, but FA2 produces a significantly higher loss than eager attention. This looks similar to previously reported issues (#26498, #28925, #28142).
From the comments on those issues, the recommended way to use FA2 is to load the model in full precision and train it under an autocast context.
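For reference, this is roughly what that recipe looks like on my side (a minimal sketch; I'm using OPT-2.7B standalone here just for illustration, and the dataloader/optimizer details are placeholders):

```python
import torch
from transformers import AutoModelForCausalLM

# Load in full precision; the autocast context below handles the bf16 compute,
# so layer norms keep fp32 parameters while FA2 still sees bf16 inputs.
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-2.7b",
    torch_dtype=torch.float32,
    attn_implementation="flash_attention_2",
).cuda()
model.train()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

for batch in dataloader:  # placeholder dataloader yielding tokenized batches on GPU
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        loss = model(**batch).loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```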
However, when using the accelerate library, the accelerator.prepare function casts the model to the specified dtype (bf16 in my case), including the layer norms.
I suspect this is what causes the problem for me, but I'm not sure.
Could you check this behavior and give any suggestions? I'm using transformers==4.40.0.dev0, accelerate==0.23.0 and flash_attn==2.5.5.
If there is any more detail I should provide, please let me know.
Thanks in advance :)
Hey 🤗 thanks for opening an issue! We try to keep the github issues for bugs/feature requests.
Could you ask your question on the forum instead? I'm sure the community will be of help!
Also, I think you should be able to prevent accelerate from casting the model to a different dtype when preparing it!
Thanks!
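For anyone landing here, a minimal sketch of that suggestion (whether prepare actually casts the weights can depend on your accelerate config, so treat this as an assumption): keep the model in fp32 and request bf16 mixed precision from the Accelerator, so the low-precision compute comes from autocast rather than from converting the whole model.

```python
from accelerate import Accelerator

# Request bf16 autocast instead of casting the weights yourself;
# model/optimizer/dataloader are assumed to be defined as in the snippet above.
accelerator = Accelerator(mixed_precision="bf16")
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

for batch in dataloader:
    with accelerator.autocast():  # bf16 compute over fp32 (layer-norm-safe) weights
        loss = model(**batch).loss
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()
```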