Has anyone succeeded in training this model? #34
I tried to train this model for a few days, but the reconstruction results are always abnormal. If anyone has succeeded in training it, can you share some tips?

Comments
The reconstructed images look like solid-color images.
Can you show the reconstructed images after training?
@bridenmj How many epochs did you use? Are you working on the ImageNet pretraining?
Yes, I'm working on ImageNet pretraining; it has passed 12,000 steps, but the output image always looks the same. I tried LFQ in my own autoencoder and training works well there, so it looks like something is wrong in the magvit2 model architecture.
Actually, I reimplemented the model structure to align with the magvit2 paper, but I find that the LFQ loss is negative and the recon loss converges easily, with or without the GAN. The reconstructed images are blurry, but not a solid color. What about you? @Jihun999
Ok, I will reimplement the model first. Thank you for your comment.
Hey, is it possible to share the code modifications for the model architecture alignment? Thanks a lot!
Someone I know has trained it successfully.
Wow, could I know who did it?
@RobertLuo1 @Jihun999 @lucidrains If you have successfully trained this model, would you like to share the pretrained weights and the modified model code?
Hello there, I tried with only MSE and then also with the other losses, and with/without the attend_space layers. All work, but I did not try to tune hyperparameters.
Thank you for sharing this, Marina! I'll see if I can find the bug, and if worst comes to worst, I can always rewrite the training code in PyTorch Lightning.
Hi, we have recently devoted a lot of effort to training the tokenizer in Magvit2, and we have now open-sourced a tokenizer trained on ImageNet. Feel free to use it. The project page is https://github.com/TencentARC/Open-MAGVIT2. Thanks @lucidrains so much for your reference code and the discussions!
Hey @lucidrains, I trained a MAGVIT2 tokenizer without modifying your implementation of the accelerate framework. As others have experienced, I initially saw just a solid block in the results/sampled.x.gif files. However, upon loading the model weights from my most recent checkpoint, I was able to get pretty good reconstructions in a sample script I wrote that performs inference without the accelerate framework. Additionally, the reconstruction MSE scores were consistent with the ones observed in your training script. This means that whatever bug others are experiencing is not the result of flawed model training, but rather something going wrong with the gif rendering. *Note: the first file is the saved gif in the
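For anyone who wants to reproduce that check, here is a rough sketch of a standalone inference script. It assumes the `VideoTokenizer` API as I read it from this repo's README (`load`, `tokenize`, `decode_from_code_indices`); the config values, checkpoint path, and shapes are placeholders, so adapt them to your own training run.

```python
import torch
import torch.nn.functional as F
from magvit2_pytorch import VideoTokenizer

# must match the config used for training; these values are placeholders
tokenizer = VideoTokenizer(
    image_size = 128,
    init_dim = 64,
    max_dim = 512,
    codebook_size = 1024,
    layers = (
        'residual',
        'compress_space',
        'residual'
    )
)

tokenizer.load('./checkpoints/checkpoint.pt')  # hypothetical path
tokenizer.eval()

video = torch.randn(1, 3, 17, 128, 128)  # (batch, channels, frames, height, width)

with torch.no_grad():
    codes = tokenizer.tokenize(video)                  # video -> discrete codes
    recon = tokenizer.decode_from_code_indices(codes)  # codes -> reconstructed video

print('recon MSE:', F.mse_loss(recon, video).item())
```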
Please check Tencent's https://github.com/TencentARC/Open-MAGVIT2, which is based on this implementation but modifies some parts.
Thanks for your reply. I did come across the Open-MAGVIT2 repo but, correct me if I'm wrong, I don't think they've implemented the video tokenizer yet?
@vincentcartillier Oh yes, but they are developing a video tokenizer... They had the same problem when training the image tokenizer and finally fixed it. I guess @RobertLuo1 can answer your question ^^
Typically, if the recon loss is below 0.03, you will see an outline of the video. What you encountered may indicate that the architecture is difficult to converge; when I manually re-implemented magvit2, it would produce reconstructions within a few hundred steps. To debug, you can first skip quantization and use only the encoder's output as the decoder's input, in order to adjust the overall model structure; once that works, add quantization back, as it is hard to train. BTW, this repo uses a 2D GAN that takes sampled frames as input, which is not aligned with the paper; you can use a 3D VQGAN instead. From my point of view, though, the discriminator's training is not the most important part. The encoder's and decoder's structures are what matter.
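To make that recipe concrete, here is a minimal sketch of the two-stage debugging loop; `encoder`, `decoder`, and `quantizer` stand in for your own modules, and the quantizer is assumed to return an LFQ-style `(quantized, indices, aux_loss)` triple.

```python
import torch.nn.functional as F

def recon_loss(video, encoder, decoder, quantizer = None):
    # stage 1: pass quantizer = None, so only the enc/dec architecture is tested
    z = encoder(video)

    # stage 2: once stage 1 converges, pass the quantizer back in
    if quantizer is not None:
        z, indices, aux_loss = quantizer(z)

    recon = decoder(z)
    return F.mse_loss(recon, video)
```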
Got it. Thanks a lot for all the tips, I'll try these out. In the meantime, do you think you could share your re-implementation of magvit2? I'm assuming it is based on this repo.
Sorry, I can't, because it's an internal project that is still under development > <. The implementation is not based on this project; I followed Google's magvit-v1 JAX repo and modified it. The adjustments between v1 and v2 are minimal.
But I use vector-quantize-pytorch's LFQ, as is used here.
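For reference, using LFQ from vector-quantize-pytorch looks roughly like this; the hyperparameters are illustrative, following that repo's README.

```python
import torch
from vector_quantize_pytorch import LFQ

quantizer = LFQ(
    codebook_size = 65536,       # must be a power of 2
    dim = 16,                    # matches log2(codebook_size) here
    entropy_loss_weight = 0.1,   # auxiliary entropy loss, encourages codebook usage
    diversity_gamma = 1.
)

feats = torch.randn(1, 16, 32, 32)                       # (batch, dim, height, width)
quantized, indices, entropy_aux_loss = quantizer(feats)  # quantized keeps the input shape
```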
Got it. Totally understandable. Thanks a lot for all the tips, I'll give it a try!
@vincentcartillier I encountered the same convergence difficulty described by @Jason3900, but before that, I found that the learning rate was not set correctly. When I checked
Got it, thanks so much for the pointer. I think the reason you're seeing such a low learning rate is because of the use of

If you haven't used LFQ, are you using FSQ (finite scalar quantization) instead, or something else? Could you try running the same thing (same learning rate) with LFQ and see if it works? It would be great to know whether that's the source of the problem we're facing.
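In case it helps to compare the two quantizers: FSQ from the same vector-quantize-pytorch library is used roughly like this, per its README; the `levels` below follow the FSQ paper's suggested setting.

```python
import torch
from vector_quantize_pytorch import FSQ

levels = [8, 5, 5, 5]         # per-dimension quantization levels, from the FSQ paper
quantizer = FSQ(levels)

x = torch.randn(1, 1024, 4)   # feature dim must equal len(levels)
xhat, indices = quantizer(x)  # xhat is quantized but keeps the shape of x
```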
Yep, but it still shouldn't take that many steps to get a reconstructed result with only one video. It may indicate that the model is too hard to converge.
@vincentcartillier sure!
To test the convergence of this enc/dec arch, I just drop the quantizer and use a continuous representation with dim 512.
Yes, I ran the same setting with LFQ but couldn't get a converged result in 40k steps. So I re-implemented the enc/dec arch according to the paper (using the code in this repo with little modification), and this time I got a surprisingly good result. I still followed @Jason3900's suggestion and skipped quantization.
Amazing! Would you be comfortable sharing the code modifications you've made (maybe via a PR, or just by sharing your fork)?
I also got something kind of working. This is the same code (i.e., no modifications, same settings as in my initial post above), except I've changed the learning rate, or rather I've turned off the
Yes, this is the modified `CausalConv3d`:

```python
import torch.nn as nn
import torch.nn.functional as F

# cast_tuple and is_odd are small helpers from this repo (magvit2_pytorch)

class CausalConv3d(nn.Module):
    def __init__(
        self,
        chan_in,
        chan_out,
        kernel_size,
        pad_mode = 'constant',
        s_stride = 1,   # spatial stride
        t_stride = 1,   # temporal stride
        **kwargs
    ):
        super().__init__()
        kernel_size = cast_tuple(kernel_size, 3)
        time_kernel_size, height_kernel_size, width_kernel_size = kernel_size

        assert is_odd(height_kernel_size) and is_odd(width_kernel_size)

        self.pad_mode = pad_mode
        time_pad = time_kernel_size - 1
        height_pad = height_kernel_size // 2
        width_pad = width_kernel_size // 2

        self.time_pad = time_pad
        # pad only the left side of the time axis, so the convolution stays causal
        self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)

        stride = (t_stride, s_stride, s_stride)
        self.conv = nn.Conv3d(
            chan_in,
            chan_out,
            kernel_size,
            stride = stride,
            **kwargs
        )

    def forward(self, x):
        # fall back to zero padding when the clip is shorter than the temporal pad
        pad_mode = self.pad_mode if self.time_pad < x.shape[2] else 'constant'
        x = F.pad(x, self.time_causal_padding, mode = pad_mode)
        return self.conv(x)
```

Next is to implement
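As a quick sanity check of the causal padding arithmetic in `CausalConv3d` above (shapes are illustrative):

```python
import torch

conv = CausalConv3d(3, 64, kernel_size = (3, 3, 3))
x = torch.randn(1, 3, 17, 64, 64)  # (batch, channels, frames, height, width)
y = conv(x)
print(y.shape)                     # torch.Size([1, 64, 17, 64, 64]); frame count is preserved
```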
@JingwWu @vincentcartillier
@Jason3900 This is great open-source work! I will check it out in detail.
Thanks!