
Commit

fix: add sampler.set_epoch(epoch) to training loop to shuffle data in distributed mode
ImahnShekhzadeh committed Jul 29, 2024
1 parent 443b000 commit f8bc646
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions examples/images/cifar10/train_cifar10.py
@@ -90,10 +90,11 @@ def train(rank, total_num_gpus, argv):
             ]
         ),
     )
+    sampler = DistributedSampler(dataset) if FLAGS.parallel else None
     dataloader = torch.utils.data.DataLoader(
         dataset,
         batch_size=batch_size_per_gpu,
-        sampler=DistributedSampler(dataset) if FLAGS.parallel else None,
+        sampler=sampler,
         shuffle=False if FLAGS.parallel else True,
         num_workers=FLAGS.num_workers,
         drop_last=True,
@@ -155,9 +156,11 @@ def train(rank, total_num_gpus, argv):
 
     global_step = 0  # to keep track of the global step in training loop
 
-    with trange(num_epochs, dynamic_ncols=True) as epoch_bar:
-        for epoch in epoch_bar:
-            epoch_bar.set_description(f"Epoch {epoch + 1}/{num_epochs}")
+    with trange(num_epochs, dynamic_ncols=True) as epoch_pbar:
+        for epoch in epoch_pbar:
+            epoch_pbar.set_description(f"Epoch {epoch + 1}/{num_epochs}")
+            if sampler is not None:
+                sampler.set_epoch(epoch)
 
             with trange(steps_per_epoch, dynamic_ncols=True) as step_pbar:
                 for step in step_pbar:
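For context, DistributedSampler seeds its per-epoch shuffle from the epoch set via `set_epoch`, so without `sampler.set_epoch(epoch)` every epoch on every rank replays the same sample order. Below is a minimal, self-contained sketch of the pattern this commit adopts; the toy dataset, the explicit `num_replicas=1, rank=0` (to avoid needing a process group), and the batch size are illustrative assumptions, not code from this repository.

```python
# Minimal sketch of the set_epoch pattern (illustrative; not the repo's training loop).
# DistributedSampler shuffles with a generator seeded from (seed + epoch), so calling
# set_epoch(epoch) once per epoch produces a different ordering each epoch.
import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

# Hypothetical toy dataset; rank/world_size passed explicitly so no dist.init_process_group is needed.
dataset = TensorDataset(torch.arange(16).float())
sampler = DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True)
loader = DataLoader(dataset, batch_size=4, sampler=sampler, shuffle=False)

for epoch in range(3):
    sampler.set_epoch(epoch)  # re-seed the shuffle for this epoch
    order = [int(x) for batch in loader for x in batch[0]]
    print(f"epoch {epoch}: {order}")
```

If the `set_epoch` call is removed, the printed order is identical across epochs, which is the behavior the commit fixes for the distributed (`FLAGS.parallel`) path.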

