Universal loader
This page briefly describes how the universal model loader works and how to add your own model with non-standard inputs or loss functions.
The script start.py initializes the Trainer class in which training takes place. The file trainer.py contains two implemented classes for different types of GAN training; the base version requires a number of mandatory parameters:
conf: dictionary of training settings, parsed from the input JSON file
**kwargs: a dictionary that must contain:
"G" # initialized generator
"D" # initialized discriminator
"start_epoch" # epoch from which to resume training
"dataloader" # dataset loader
"optim_G" # generator optimizer
"optim_D" # discriminator optimizer
"gen_loss" # generator loss function
"disc_loss" # discriminator loss function
"z_dim" # dimension of the generator's latent vector
"device" # device to train on
For more information about what the JSON file should contain, see the page on configuring the config.
Now let's look at how to add your own model to this loader, starting with start.py:
# Loading all models
generators = {}
discriminators = {}
for name_model in get_py_modules('models'):
model = dynamic_import(f'models.{name_model}')
generators = {**generators, **model.generators}
discriminators = {**discriminators, **model.discriminators}
assert conf["Generator"] in generators.keys()
assert conf["Discriminator"] in discriminators.keys()
That is, the model named in the JSON file is found as follows: the script traverses all Python modules in the models folder and collects from each one two dictionaries that map a GAN name to its classes:
generators["NameGAN"] = NameGANgenerator
discriminators["NameGAN"] = NameGANdiscriminator
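For reference, the helpers get_py_modules and dynamic_import could plausibly be implemented with the standard library as follows. This is a sketch of how such discovery usually works, not necessarily the repository's actual implementation.

```python
import importlib
import pathlib

def get_py_modules(folder):
    """Yield the name (stem) of every *.py file in `folder`,
    skipping __init__.py. Plausible reimplementation; the real
    helper in the repository may differ."""
    for path in sorted(pathlib.Path(folder).glob("*.py")):
        if path.stem != "__init__":
            yield path.stem

def dynamic_import(module_path):
    """Import a module by its dotted path, e.g. 'models.dcgan'."""
    return importlib.import_module(module_path)
```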
Thus, to add a model in this semi-automatic mode, it is enough to place the file with your model in the models folder and add a couple of lines:
import torch.nn as nn

from utils import register

generators = register.ClassRegistry()
discriminators = register.ClassRegistry()

@generators.add_to_registry("NameGAN")
class Generator(nn.Module):
    ...

@discriminators.add_to_registry("NameGAN")
class Discriminator(nn.Module):
    ...
Accordingly, the decorator is implemented in utils.register; it adds the decorated class to the corresponding dictionary.
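A minimal version of such a registry might look like the sketch below; the actual utils.register implementation may differ in details, but the idea is a dictionary populated by a class decorator.

```python
class ClassRegistry:
    """Maps a registered name to the class stored under that name."""

    def __init__(self):
        self._registry = {}

    def add_to_registry(self, name):
        """Return a decorator that records the class under `name`."""
        def decorator(cls):
            self._registry[name] = cls
            return cls  # return the class unchanged so it stays usable
        return decorator

    def __getitem__(self, name):
        return self._registry[name]

    def keys(self):
        return self._registry.keys()

generators = ClassRegistry()

@generators.add_to_registry("NameGAN")
class Generator:
    pass
```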
Loss functions are added in exactly the same way, only to the file losses.py.
Now let's analyze the case of non-standard loss functions or models. The base Trainer implements a universal train_loop that calls several helper methods:
def logger(self, data): # logs training progress
def save_model(self, epoch): # saves model weights
def generate_images(self, cnt=1): # returns cnt batches of generated images
def train_disc(self, real_imgs, fake_imgs): # trains the discriminator on real and fake images
def train_gen(self, fake_imgs): # trains the generator on the generated images
def train_loop(self): # the main loop that ties the methods above together
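To illustrate how train_loop composes these helper methods, here is a pure-Python stub of the call structure. It is only a sketch: the real BaseGANTrainer operates on PyTorch tensors and a DataLoader, while this class uses strings and constant losses as stand-ins.

```python
# Sketch of the train_loop call structure with placeholder data.
class SketchTrainer:
    def __init__(self, epochs, batches_per_epoch):
        self.epochs = epochs
        self.batches_per_epoch = batches_per_epoch
        self.log = []  # stand-in for a real logging backend

    def logger(self, data):
        self.log.append(data)

    def save_model(self, epoch):
        self.log.append(f"saved epoch {epoch}")

    def generate_images(self, cnt=1):
        return [f"fake_batch_{i}" for i in range(cnt)]

    def train_disc(self, real_imgs, fake_imgs):
        return 0.5   # placeholder discriminator loss

    def train_gen(self, fake_imgs):
        return 0.25  # placeholder generator loss

    def train_loop(self):
        for epoch in range(self.epochs):
            for batch in range(self.batches_per_epoch):
                real_imgs = f"real_batch_{batch}"
                fake_imgs = self.generate_images(cnt=1)[0]
                d_loss = self.train_disc(real_imgs, fake_imgs)
                g_loss = self.train_gen(self.generate_images(cnt=1)[0])
                self.logger({"epoch": epoch, "d_loss": d_loss, "g_loss": g_loss})
            self.save_model(epoch)
```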
Accordingly, if a more complex loss function requires not only real and fake images but also, for example, additional parameters or the discriminator itself as input (as in WGAN-GP), it is enough to inherit from the base Trainer and override the necessary methods. For example, for WGAN-GP:
@trainers.add_to_registry(name="gp")
class GpGANTrainer(BaseGANTrainer):
    def __init__(self, conf, **kwargs):
        super().__init__(conf, **kwargs)

    def train_disc(self, real_imgs, fake_imgs):
        lambda_gp = self.conf["Loss_config"]["lambda_gp"]
        return self.disc_loss(self.D, real_imgs, fake_imgs, lambda_gp)