
Question about training loop #669

Open
AnabetsyR opened this issue Jun 24, 2022 · 4 comments

Comments
@AnabetsyR

Hi! I'm trying to fork the repo and add some functionality for an experiment, but that requires an addition to the training loop. I've read the documentation and the code but I can't seem to understand where the training loop itself is defined. Can somebody point me in the right direction?

Thanks in advance!

@jonhare
Collaborator

jonhare commented Jun 24, 2022

In trial.py see method run (looping over the epochs):

def run(self, epochs=1, verbose=-1):

and _fit_pass (looping over the batches):

def _fit_pass(self, state):

I'd have thought most additions to the training loop could be made via one of the many callback hooks rather than by modifying the source itself, though.
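For illustration, a minimal (untested) sketch of a decorator-based callback that hooks the end of each training step without touching trial.py — my_callback and the surrounding model/optimizer/loss_fn/loader names are just placeholders:

import torchbearer

# Illustrative callback: runs once per training batch and can read/modify
# the mutable state dict (epoch number, current loss, model, ...)
@torchbearer.callbacks.on_step_training
def my_callback(state):
    if state[torchbearer.EPOCH] > 0:
        print(state[torchbearer.LOSS].item())

# model, optimizer, loss_fn and loader defined as usual
trial = torchbearer.Trial(model, optimizer, loss_fn, callbacks=[my_callback])
trial.with_train_generator(loader).run(epochs=10)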

@AnabetsyR
Author

Thanks for getting back to me! I'm trying to integrate stochastic weight averaging (SWA) as in swa_utils from PyTorch. The way they implemented SWA is like a wrapper on top of the torch optimizer (see their example below). Based on this, it seems I will need to pass the swa_model, the optimizer, and at least the swa_scheduler, and then handle the parameter updates after SWA kicks in during the training loop. Do you have any suggestions on how to go about this? Sorry, I'm new to both Torchbearer and SWA... I really appreciate any suggestions.

loader, optimizer, model, loss_fn = ...
swa_model = torch.optim.swa_utils.AveragedModel(model)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=300)
swa_start = 160
swa_scheduler = SWALR(optimizer, swa_lr=0.05)

for i in range(300):
    for input, target in loader:
        optimizer.zero_grad()
        loss_fn(model(input), target).backward()
        optimizer.step()
    if i > swa_start:
        swa_model.update_parameters(model)
        swa_scheduler.step()
    else:
        scheduler.step()

# Update bn statistics for the swa_model at the end
torch.optim.swa_utils.update_bn(loader, swa_model)

@jonhare
Collaborator

jonhare commented Jun 24, 2022

Really, really (really!) untested, but something like this should be equivalent, based on the above and guessing the correct indentation:

import torch
import torchbearer
from torch.optim.swa_utils import SWALR

loader, optimizer, model, loss_fn = ...
swa_model = torch.optim.swa_utils.AveragedModel(model)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=300)

swa_scheduler = SWALR(optimizer, swa_lr=0.05)
swa_start = 160

@torchbearer.callbacks.on_step_training
def swa_callback(state):
    if state[torchbearer.EPOCH] > swa_start:
        swa_model.update_parameters(model)  # or, avoiding the global access: swa_model.update_parameters(state[torchbearer.MODEL])
        swa_scheduler.step()
    else:
        scheduler.step()

trial = torchbearer.Trial(model, optimizer, loss_fn, callbacks=[swa_callback])
trial = trial.with_train_generator(loader)
trial.run(300)

# Update bn statistics for the swa_model at the end
torch.optim.swa_utils.update_bn(loader, swa_model)
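(One thing to check: on_step_training fires once per batch, whereas the PyTorch snippet steps the schedulers once per epoch, so if that matters you could presumably hang the same body off @torchbearer.callbacks.on_end_epoch instead.)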

@AnabetsyR
Author

@jonhare Thank you so much! I've spent the weekend playing around with it. At first it was behaving oddly: instead of passing the swa_callback on its own, I was adding it to my existing list of callbacks (which has other things in it) and passing the whole list. It seems some things don't play well together, but I suspect that's due to some unnecessary schedulers etc. in there. Of course I still have to combine the necessary portions, but it's working!

Note that I added the update_bn line before the trial as it looked like it wasn't updating properly. Hopefully this is correct!

I'm extremely grateful that you took time out of your day to help me out! You really saved me a ton of headache. And I love torchbearer so much that I didn't want to switch away from it.
