Add begin_epoch to FairseqTask (#984)
Summary:
Adds a begin_epoch hook to FairseqTask.
Pull Request resolved: fairinternal/fairseq-py#984

Differential Revision: D19429433

Pulled By: myleott

fbshipit-source-id: 367bd4d0d2d2bc995cca9ac151256c77ede36c83
ebetica authored and facebook-github-bot committed Jan 17, 2020
1 parent 09eb023 commit 122fc1d
Showing 2 changed files with 8 additions and 1 deletion.
6 changes: 5 additions & 1 deletion fairseq/tasks/fairseq_task.py
@@ -308,8 +308,12 @@ def inference_step(self, generator, models, sample, prefix_tokens=None):
         with torch.no_grad():
             return generator.generate(models, sample, prefix_tokens=prefix_tokens)
 
+    def begin_epoch(self, epoch, model):
+        """Hook function called before the start of each epoch."""
+        pass
+
     def update_step(self, num_updates):
-        """Task level update when number of update increases.
+        """Task level update when number of updates increases.
         This is called after the optimization step and learning rate
         update at each iteration.
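The new hook follows a common template-method pattern: the base class defines a no-op `begin_epoch`, and subclasses override it for per-epoch setup. A standalone sketch of that pattern (BaseTask, CurriculumTask, and train_epochs are illustrative stand-ins, not fairseq's actual classes):

```python
class BaseTask:
    def begin_epoch(self, epoch, model):
        """Hook function called before the start of each epoch."""
        pass


class CurriculumTask(BaseTask):
    """Hypothetical task that raises a difficulty cap each epoch."""

    def __init__(self):
        self.max_difficulty = 0

    def begin_epoch(self, epoch, model):
        # Runs once per epoch, before any batches are consumed.
        self.max_difficulty = epoch


def train_epochs(task, model, num_epochs):
    history = []
    for epoch in range(1, num_epochs + 1):
        # Mirrors the call site added to the train() loop in this commit.
        task.begin_epoch(epoch, model)
        history.append(task.max_difficulty)
    return history


print(train_epochs(CurriculumTask(), model=None, num_epochs=3))  # → [1, 2, 3]
```

Because the base implementation is a no-op, existing tasks that do not override the hook are unaffected.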
3 changes: 3 additions & 0 deletions fairseq_cli/train.py
@@ -158,6 +158,9 @@ def train(args, trainer, task, epoch_itr):
         args, itr, epoch_itr.epoch, no_progress_bar='simple',
     )
 
+    # task specific setup per epoch
+    task.begin_epoch(epoch_itr.epoch, trainer.get_model())
+
     valid_subsets = args.valid_subset.split(',')
     max_update = args.max_update or math.inf
     for samples in progress:
