Skip to content

Commit

Permalink
extend test and patch bug (#2028)
Browse files Browse the repository at this point in the history
  • Loading branch information
mvpatel2000 authored Mar 3, 2023
1 parent 3618c63 commit 11d1c51
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 9 deletions.
3 changes: 2 additions & 1 deletion composer/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2604,10 +2604,11 @@ def eval(
evaluators = self.state.evaluators

for evaluator in evaluators:
eval_subset_num_batches = evaluator.subset_num_batches if subset_num_batches == -1 else subset_num_batches
self._eval_loop(
dataloader=evaluator.dataloader,
dataloader_label=evaluator.label,
subset_num_batches=subset_num_batches,
subset_num_batches=eval_subset_num_batches,
metrics=self.state.eval_metrics[evaluator.label],
)
if eval_passed_in:
Expand Down
20 changes: 12 additions & 8 deletions tests/trainer/test_trainer_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,23 +103,27 @@ def test_trainer_eval_loop():
assert trainer.state.eval_metrics['eval']['MulticlassAccuracy'].compute() != 0.0


def test_trainer_eval_subset_num_batches():
@pytest.mark.parametrize('evaluator_on_init,subset_on_init', [[True, True], [True, False], [False, False]])
def test_trainer_eval_subset_num_batches(evaluator_on_init: bool, subset_on_init: bool):
dataset = RandomClassificationDataset()
eval_dataloader = DataLoader(
dataset=dataset,
sampler=dist.get_sampler(dataset),
)

# Construct the trainer
event_counter_callback = EventCounterCallback()
trainer = Trainer(
model=SimpleModel(),
callbacks=[event_counter_callback],
eval_dataloader=eval_dataloader if evaluator_on_init else None,
eval_subset_num_batches=1 if subset_on_init else -1,
)

# Evaluate the model
dataset = RandomClassificationDataset()
eval_dataloader = DataLoader(
dataset=dataset,
sampler=dist.get_sampler(dataset),
)
trainer.eval(
eval_dataloader=eval_dataloader,
subset_num_batches=1,
eval_dataloader=eval_dataloader if not evaluator_on_init else None,
subset_num_batches=1 if not subset_on_init else -1,
)

# Ensure that just one batch was evaluated
Expand Down

0 comments on commit 11d1c51

Please sign in to comment.