How do I understand if the training is good? #1096

Open
GZeta95 opened this issue Jul 19, 2020 · 3 comments

Comments

@GZeta95

GZeta95 commented Jul 19, 2020

Hi, thanks for your great work. I have trained the cycle_gan model on my dataset for 400 epochs. I generated results from several epochs, and qualitatively they have looked similar since epoch 100: some samples are better at epoch 100 and others at epoch 400. So I plotted the loss functions from loss_log.txt, and these are the results:
[Plots, one per loss: Cycle_A, Cycle_B, G_A, G_B, Idt_A, Idt_B, D_A, D_B]

Each plot shows the value of the loss at the end of each epoch.
The losses oscillate.
What values should I expect if the training has gone well? How do I choose how many epochs to train the net?
How do I tell whether the network is underfitting or overfitting?

Thank you a lot!

@junyanz
Owner

junyanz commented Jul 19, 2020

Your loss plots look normal. Usually, cycle-consistency loss and identity loss decrease during training, while GAN losses oscillate. To evaluate the quality or detect overfitting/underfitting, you need to apply additional evaluation metrics to your training and test images. The metrics are task-specific. See more discussion in #730.
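
As a rough illustration of the pattern described above, the sketch below compares the average of the first few and last few per-epoch values of a loss. The function name `relative_change`, the window size, and the sample numbers are invented for illustration and are not part of the repository. A falling cycle loss only shows that the curve looks as expected; it does not measure image quality.

```python
def relative_change(values, window=3):
    """Relative change between the mean of the first and last `window` values."""
    if not values:
        raise ValueError("values must not be empty")
    head = values[:window]
    tail = values[-window:]
    first = sum(head) / len(head)
    last = sum(tail) / len(tail)
    if first == 0:
        raise ValueError("mean of the first window is zero")
    return (last - first) / abs(first)

# Invented per-epoch values, for illustration only.
cycle_A = [2.0, 1.7, 1.5, 1.2, 1.1, 0.9, 0.8, 0.8, 0.7, 0.7]  # steadily decreasing
G_A = [0.4, 0.6, 0.3, 0.5, 0.4, 0.6, 0.3, 0.5, 0.4, 0.6]      # oscillating, no clear trend

print(f"cycle_A change: {relative_change(cycle_A):+.0%}")
print(f"G_A change:     {relative_change(G_A):+.0%}")
```

A large negative change for the cycle or identity loss together with a small change for the GAN losses would match the "normal" picture; anything beyond that needs a task-specific metric on held-out images.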

@BenoitKAO

BenoitKAO commented Sep 18, 2021

Hi GZeta95,
May I ask how to turn 'iterations' into 'epochs' from loss_log.txt?
Could you share your code for plotting these diagrams?
Thank you very much.

@Djoels

Djoels commented Jun 12, 2023

A bit late to the party, but in reply to GZeta95's message: I used the last row of each epoch.

How I plotted the same (on my own dataset, of course):

  import re
  from datetime import datetime

  import matplotlib.pyplot as plt
  import pandas as pd

  LOSS_COLS = ["cycle_A", "cycle_B", "G_A", "G_B", "idt_A", "idt_B", "D_A", "D_B"]

  def parse_date(logline):
      # The timestamp sits at characters 32-55 of a "Training Loss" header line.
      return datetime.strptime(logline[32:56], '%a %b %d %H:%M:%S %Y')

  def read_log(filepath, run="final"):
      '''
      Read an i2i loss_log.txt from the given path.
      Parameters:
          - filepath: path to the txt file
          - run: "final" for the last run, or an integer run number (e.g. 1, 2, 3)

      Returns: a pandas DataFrame with one row per epoch (the last logged row of that epoch)
      '''
      run_nr = 0
      all_data = []
      current_logdate = ""
      with open(filepath, 'r') as file:
          for line in file:
              if "Training Loss" in line:
                  # Each header line marks the start of a new training run.
                  run_nr += 1
                  current_logdate = parse_date(line).isoformat(sep=" ")
                  print(f"new log with starting logdate {current_logdate}")
                  continue
              # Strip punctuation so the line reads "key value key value ...";
              # '-' is kept so that negative loss values keep their sign.
              tokens = re.sub(r'[^0-9a-zA-Z ._-]+', '', line).split()
              if not tokens:
                  continue  # skip blank lines
              line_data = dict(zip(tokens[0::2], tokens[1::2]))
              line_data.update({"run": run_nr, "run_date": current_logdate})
              all_data.append(line_data)
      df = pd.DataFrame(all_data)
      df["epoch"] = df["epoch"].astype(int)
      df["iters"] = df["iters"].astype(int)
      for col in LOSS_COLS:
          df[col] = df[col].astype(float)
      if isinstance(run, int):
          df = df[df["run"] == run]
      elif run == "final":
          df = df[df["run"] == run_nr]
      # Keep only the last logged row of each epoch.
      df = df.groupby(by="epoch", group_keys=False).last().reset_index().sort_values(by="epoch")
      return df

  loss_log_file = "loss_log.txt"  # placeholder: path to your own loss_log.txt
  df = read_log(loss_log_file, run="final")
  for col in LOSS_COLS:
      df.plot("epoch", col)
      plt.title(col)
      plt.show()
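
On the question of turning iterations into epochs: each loss line in loss_log.txt appears to carry both counters, so no conversion is needed, only grouping by the epoch field. The sketch below is a standard-library variant that averages every logged value within an epoch instead of keeping the last row, which may give a less noisy curve. The line format in `SAMPLE` is an assumption inferred from the parsing code above, the numbers are invented, and `parse_line` and `epoch_means` are hypothetical helper names.

```python
import re
from collections import defaultdict

# Matches "name: value" pairs such as "epoch: 1" or "G_A: 0.301".
PAIR = re.compile(r"(\w+):\s*(-?\d+(?:\.\d+)?)")
META = ("epoch", "iters", "time", "data")  # bookkeeping fields, not losses

def parse_line(line):
    """Return {name: float} for one loss line, or {} for a run header line."""
    if "Training Loss" in line:
        return {}
    return {name: float(value) for name, value in PAIR.findall(line)}

def epoch_means(lines):
    """Average every loss over all rows logged within the same epoch."""
    sums = defaultdict(lambda: defaultdict(float))
    counts = defaultdict(int)
    for line in lines:
        row = parse_line(line)
        if "epoch" not in row:
            continue  # header or blank line
        epoch = int(row["epoch"])
        counts[epoch] += 1
        for name, value in row.items():
            if name not in META:
                sums[epoch][name] += value
    return {
        epoch: {name: total / counts[epoch] for name, total in sums[epoch].items()}
        for epoch in sorted(counts)
    }

# Assumed log format with invented numbers, for illustration only.
SAMPLE = [
    "================ Training Loss (Sun Jul 19 12:34:56 2020) ================",
    "(epoch: 1, iters: 100, time: 0.208, data: 0.291) D_A: 0.400 G_A: 0.300 cycle_A: 2.000 ",
    "(epoch: 1, iters: 200, time: 0.208, data: 0.002) D_A: 0.200 G_A: 0.500 cycle_A: 1.000 ",
    "(epoch: 2, iters: 100, time: 0.208, data: 0.002) D_A: 0.250 G_A: 0.450 cycle_A: 0.900 ",
]

means = epoch_means(SAMPLE)
for epoch, losses in means.items():
    print(epoch, losses)
```

Whether the last row or the per-epoch mean is the better summary depends on how noisy the per-iteration values are; the mean smooths more but hides within-epoch swings.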
