Low-shot training #5

Open
xinyouduogao opened this issue Jun 13, 2021 · 6 comments

@xinyouduogao

How do I train LeCam-GAN on the low-shot image generation datasets? Thanks.

@hytseng0509
Collaborator

Hi,

We modify the code of DiffAug to train on the low-shot image generation task.
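For reference, the core of such a modification would presumably be the LeCam term added to the discriminator loss. The framework-free sketch below assumes the hinged (ReLU) form that the code posted later in this thread also uses: real logits are penalized for rising above an EMA anchor of the fake logits, and fake logits for falling below an EMA anchor of the real logits. The scalar anchors (tracking batch-mean logits) and the names `ScalarEMA` and `lecam_reg` are assumptions for illustration, not the authors' low-shot code.

```python
from dataclasses import dataclass


@dataclass
class ScalarEMA:
    """Hypothetical helper: exponential moving average of a single number."""
    decay: float = 0.99
    value: float = 0.0

    def update(self, x: float) -> float:
        self.value = self.decay * self.value + (1.0 - self.decay) * x
        return self.value


def mean(xs):
    return sum(xs) / len(xs)


def lecam_reg(real_logits, fake_logits, ema_real, ema_fake):
    """Hinged LeCam term (assumed form): real logits above the fake anchor and
    fake logits below the real anchor are penalized; the ReLU means a logit on
    the other side of its anchor contributes nothing."""
    real_term = mean([max(r - ema_fake.value, 0.0) ** 2 for r in real_logits])
    fake_term = mean([max(ema_real.value - f, 0.0) ** 2 for f in fake_logits])
    return real_term + fake_term


# One discriminator step with toy logits (values are illustrative only).
ema_real, ema_fake = ScalarEMA(), ScalarEMA()
real, fake = [2.0, 1.0], [-1.0, 0.5]
ema_real.update(mean(real))  # anchors track the batch-mean logits
ema_fake.update(mean(fake))
lam = 0.0001  # LeCam weight quoted later in this thread
d_loss_extra = lam * lecam_reg(real, fake, ema_real, ema_fake)
print(d_loss_extra)
```

In a real training loop the logits would be framework tensors and the anchors would be updated with detached batch means, so no gradient flows through them.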

@liang-hou

> Hi,
>
> We modify the code of DiffAug to train on the low-shot image generation task.

Hi, would you mind sharing the source code for the low-shot generation experiments? It would help us a lot.

@aprilycliu

Hello,
I'm also trying to implement the LeCam loss on top of DiffAug for low-shot generation. Could you please share the training script? Thanks!

@SushkoVadim

Hi,

I would like to add my vote to this discussion: it would be very helpful to see the training script or the modified training.loss.py file.
I tried to implement the method myself on top of DiffAugm, but I have not managed to reproduce the results in Table 4 of the supplementary material.
Thanks in advance!

@roadjiang
Contributor

This experiment only appears in the supplementary material and is not part of the main paper.
We'd love to release the code, but it may require additional approvals.
We will try our best and see what we can do.

@SushkoVadim

Hi,
Thanks a lot for answering! I understand that the clearing process for open sourcing can be time-consuming and burdensome.
To make answering easier, could I ask you to comment on my attempt to reproduce the training? Perhaps I am missing some implementation details that turn out to be important. This could also help others trying to reproduce the low-shot results.

My modification to DiffAugm was to add the LeCam regularization to the training.loss.py module.

1. Specifically, I added a simple EMA tracker for both the real and the fake logits to the StyleGAN2Loss class:

```python
# In StyleGAN2Loss.__init__:
self.val_ema_real = val_EMA()
self.val_ema_fake = val_EMA()

# Helper class:
class val_EMA():
    def __init__(self, ema_decay=0.99):
        self.ema_decay = ema_decay
        self.mem_value = 0  # running EMA, starts at 0

    def add_step(self, cur_values):
        # Blend the new (detached) logits into the running average.
        self.mem_value = self.ema_decay * self.mem_value + (1 - self.ema_decay) * cur_values.detach()

    def get_cur_val(self):
        return self.mem_value
```
2. During training, I add the new logit values to the EMA accumulation and then add the regularization to the objective functions:

```python
# For fakes:
loss_emaCR_fake = 0
if do_emaCR:
    self.val_ema_fake.add_step(gen_logits)
    # Penalize fake logits that fall below the EMA of the real logits.
    loss_emaCR_fake = self.cr_ema_lambda * torch.mean(torch.square(torch.nn.functional.relu(self.val_ema_real.get_cur_val() - gen_logits)))

with torch.autograd.profiler.record_function('Dgen_backward'):
    (loss_Dgen + loss_emaCR_fake).mean().mul(gain).backward()
....
# For reals:
loss_emaCR_real = 0
if do_emaCR:
    self.val_ema_real.add_step(real_logits)
    # Penalize real logits that rise above the EMA of the fake logits.
    loss_emaCR_real = self.cr_ema_lambda * torch.mean(torch.square(torch.nn.functional.relu(real_logits - self.val_ema_fake.get_cur_val())))

with torch.autograd.profiler.record_function(name + '_backward'):
    (real_logits * 0 + loss_Dreal + loss_Dr1 + loss_emaCR_real).mean().mul(gain).backward()
```
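One detail that may be worth checking in the code above: `add_step` stores `cur_values.detach()` directly, so each EMA holds a tensor with one entry per sample in the batch, and every logit is compared with the anchor at its own batch index. If the anchors are meant to be single averages of the discriminator's outputs (an assumption here, based on how the LeCam paper describes its EMA anchors), the batch mean would be tracked instead. The plain-Python sketch below, with hypothetical helper names, contrasts the two on a toy batch after one update from the zero initialization used in `val_EMA`.

```python
DECAY = 0.99  # ema_decay used in the thread


def mean(xs):
    return sum(xs) / len(xs)


def fake_hinge_penalty(anchor, fake_logits):
    """mean(relu(anchor - fake)^2). `anchor` is either one float (scalar EMA)
    or a list with one entry per batch position (per-position EMA)."""
    if isinstance(anchor, list):
        pairs = zip(anchor, fake_logits)            # index-by-index pairing
    else:
        pairs = ((anchor, f) for f in fake_logits)  # same anchor for every sample
    return mean([max(a - f, 0.0) ** 2 for a, f in pairs])


# Toy logits; the EMA starts at 0, as val_EMA's mem_value does.
real_logits = [3.0, -1.0]
fake_logits = [0.0, 0.0]

# The first update from 0 is DECAY * 0 + (1 - DECAY) * x.
per_position_anchor = [(1 - DECAY) * r for r in real_logits]  # storing the full logits tensor
scalar_anchor = (1 - DECAY) * mean(real_logits)               # EMA of the batch mean (assumed)

per_position = fake_hinge_penalty(per_position_anchor, fake_logits)
scalar = fake_hinge_penalty(scalar_anchor, fake_logits)
print(f"per-position: {per_position:.6f}  scalar: {scalar:.6f}")
```

Here the two anchors already give different penalties (4.5e-4 versus 1e-4). With a fixed batch size both versions run, but the per-position anchor averages only one sample per index per step, so it is noisier; whether this accounts for the gap to Table 4 is not established.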

I ran the training for the same 300 kimg, using self.cr_ema_lambda = 0.0001 and self.ema_decay = 0.99, which matches the description in the supplementary material. After training finished, I measured the following best FID across epochs (lower is better):

| Use LeCam CR? | Source | Animal Face - Cat | Animal Face - Dog | Obama | Panda | Grumpy Cat |
|---|---|---|---|---|---|---|
| No | reported | 42.10 | 58.47 | 47.09 | 12.10 | 27.21 |
| Yes | reported | 33.16 | 54.88 | 33.16 | 10.16 | 24.93 |
| No | reproduced | 40.20 | 67.12 | 48.31 | 14.44 | 27.09 |
| Yes | reproduced | 39.55 | 64.84 | 50.80 | 14.82 | 29.66 |

So I am able to reproduce the original numbers from the DiffAugm repository. However, the results after adding LeCam CR do not seem to match Table 4; this step even hurts FID on 3 of the 5 datasets.
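The "3 of 5" count can be checked directly from the table. FID is lower-is-better, so a positive change (with LeCam CR minus without) means the regularizer made things worse. A small plain-Python sketch using only the values in the table:

```python
datasets = ["Animal Face - Cat", "Animal Face - Dog", "Obama", "Panda", "Grumpy Cat"]

# FID values copied from the table (lower is better).
reported_base = [42.10, 58.47, 47.09, 12.10, 27.21]
reported_lecam = [33.16, 54.88, 33.16, 10.16, 24.93]
reproduced_base = [40.20, 67.12, 48.31, 14.44, 27.09]
reproduced_lecam = [39.55, 64.84, 50.80, 14.82, 29.66]


def fid_change(base, with_cr):
    """FID(with LeCam CR) - FID(without); positive means the regularizer hurt."""
    return [round(w - b, 2) for b, w in zip(base, with_cr)]


reported = fid_change(reported_base, reported_lecam)
reproduced = fid_change(reproduced_base, reproduced_lecam)
hurt = [name for name, d in zip(datasets, reproduced) if d > 0]

for name, r, p in zip(datasets, reported, reproduced):
    print(f"{name:18s} reported {r:+6.2f}  reproduced {p:+6.2f}")
print(f"LeCam CR hurt FID on {len(hurt)}/{len(datasets)} datasets: {hurt}")
```

By this measure every reported change is an improvement (between 1.94 and 13.93 FID), while the reproduced changes range from -2.28 to +2.57.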

It would indeed be very helpful to figure out where my misunderstanding lies.
Regards,
Vadim
