High Batch Size with SD3 Dreambooth Destabilizes Training #8621
-
This seems more like a technical discussion to me, and it's very training-specific, so I am going to transfer this to "Discussions". Cc'ing @bghira @AmericanPresidentJimmyCarter to check if they have some suggestions here.
-
A quick suggestion would be to experiment with different LRs and LR schedulers that scale with the batch size.
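For illustration, here is a rough sketch of what that could look like; the linear-scaling heuristic, the base LR, the scheduler choice, and the placeholder model are assumptions for the example, not values recommended in this thread:

```python
# Hedged sketch: tie the LR to the batch size and add a warmup scheduler.
# Everything here (base LR, scaling rule, scheduler, tiny model) is illustrative.
import torch
from diffusers.optimization import get_scheduler

batch_size = 40
max_train_steps = 10_000
base_lr = 1e-7                               # per-sample base LR (assumption)
learning_rate = base_lr * batch_size         # linear scaling with batch size

model = torch.nn.Linear(8, 8)                # placeholder for the trainable params
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

lr_scheduler = get_scheduler(
    "cosine",                                # e.g. also try "constant_with_warmup"
    optimizer=optimizer,
    num_warmup_steps=500,
    num_training_steps=max_train_steps,
)

# In the training loop: optimizer.step(); lr_scheduler.step(); optimizer.zero_grad()
```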
-
Gotcha. Does SD3 so far seem to need a different LR than previous versions did? I've been running with roughly 1e-7 * batch_size before this; I'll give some other settings a try later. Attaching my training script for context; excuse the state it's in after a week of bugfixing with it.
-
Yeah, it needs a much lower LR. For what it's worth, 40 isn't considered a very high batch size for diffusion transformers; they benefit from it being as high as you can push it.
-
All noted, I'll try batch size 80 with an LR of 5e-7 and see if that works.
-
Training seems much more stable after further tests, so I'm closing this. Thank you all for the advice.
-
Hey,
-
Describe the bug
I have been trying to train a slightly modified IP-Adapter architecture for SD3 over the past few days and wrote the training script by copying the up-to-date weighting, noise, and loss code from the train_dreambooth_sd3.py script. While training, nothing I could do would get the model to train: after 2,000-3,000 steps at batch size 40 and LR 5e-6, the output would just turn to mush.
Now, after dropping to a batch size of 4 and an LR of 8e-7, the problem appears to have gone away completely: no hints of degradation 40,000 steps in.
The only other possible explanation is that around that time I also removed a torch.autocast block around the model forward pass that shouldn't have been there, given I was also using Accelerate. I don't think that was the source of the issue, though, as it was present in previous, perfectly functional runs using a very similar script. I'll do an extra test to check at some point over the next few days when I have a chance.
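For reference, a minimal sketch of the autocast point above: if Accelerate is already configured for mixed precision (e.g. via `accelerate config` or `--mixed_precision fp16`), an extra torch.autocast block around the forward pass is redundant and worth removing. The model, optimizer, and tensors below are placeholders, not code from my actual script.

```python
# Minimal sketch, assuming mixed precision is configured through Accelerate
# itself; the tiny model and tensors are stand-ins for the real training setup.
import torch
from accelerate import Accelerator

accelerator = Accelerator()                        # mixed precision configured externally
model = torch.nn.Linear(8, 8)                      # stand-in for the SD3 transformer
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-7)
model, optimizer = accelerator.prepare(model, optimizer)

x = torch.randn(4, 8, device=accelerator.device)

# What was removed (a manually nested autocast around the forward pass):
# with torch.autocast(device_type=accelerator.device.type, dtype=torch.float16):
#     pred = model(x)

# Letting the Accelerate-prepared model handle precision instead:
pred = model(x)
loss = pred.float().mean()
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
```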
I have been using the newly modified weighting pushed a day or two ago and logit-normal weighting (the same seemed to happen with the original sigma_sqrt weighting).
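For context, here is a self-contained sketch of the logit-normal sampling and loss-weighting pattern I copied over, using the helpers from diffusers.training_utils. The tiny placeholder model, tensor shapes, and the simplification of treating the sampled density directly as sigma are assumptions for the example, not my actual script.

```python
# Sketch of logit-normal timestep sampling plus SD3-style loss weighting,
# assuming the diffusers.training_utils helpers; shapes and model are placeholders.
import torch
from diffusers.training_utils import (
    compute_density_for_timestep_sampling,
    compute_loss_weighting_for_sd3,
)

batch_size = 4
latents = torch.randn(batch_size, 16, 32, 32)               # stand-in for VAE latents
noise = torch.randn_like(latents)
model = torch.nn.Conv2d(16, 16, kernel_size=3, padding=1)   # stand-in for the transformer

# Sample timesteps from a logit-normal density (the "logit_normal" scheme).
u = compute_density_for_timestep_sampling(
    weighting_scheme="logit_normal",
    batch_size=batch_size,
    logit_mean=0.0,
    logit_std=1.0,
)
sigmas = u.view(-1, 1, 1, 1)                                 # simplified: use u as sigma in [0, 1]

# Rectified-flow style interpolation and target.
noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise
target = noise - latents

model_pred = model(noisy_model_input)

# Per-sample loss weighting ("logit_normal" falls back to uniform weights here;
# "sigma_sqrt" would instead return sigmas**-2.0).
weighting = compute_loss_weighting_for_sd3(weighting_scheme="logit_normal", sigmas=sigmas)
loss = (weighting * (model_pred - target) ** 2).mean()
loss.backward()
```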
Reproduction
Run SD3 training using logit-normal weighting, presumably with a LoRA or similar, with a batch size of at least 40 and an LR of around 6e-6.
Logs
System Info
4x A100 RunPod server
Who can help?
@sayakpaul