Properly handle empty metric_names passed to Trainer._filter_metrics (#…
irenedea authored Nov 9, 2023
1 parent 3830a55 commit 6f29ad6
Showing 3 changed files with 47 additions and 9 deletions.
3 changes: 1 addition & 2 deletions composer/core/evaluator.py
@@ -86,11 +86,10 @@ def __init__(
         self.label = label
         self.dataloader = ensure_data_spec(dataloader)
 
-        self.metric_names = []
         if metric_names is not None:
             if not isinstance(metric_names, list):
                 raise ValueError(f'``metric_names`` should be a list of strings, not a {type(metric_names)}')
-            self.metric_names = metric_names
+        self.metric_names = metric_names
 
         self.subset_num_batches = subset_num_batches
         self._eval_interval = None
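The effect of this change: metric_names=None (the default) is now preserved as None, which Trainer._filter_metrics below interprets as "keep all metrics", while an explicit empty list disables every metric. A minimal usage sketch, assuming composer is installed and that Evaluator is importable from composer.core; the dataset here is a placeholder, not the repo's RandomClassificationDataset:

import torch
from torch.utils.data import DataLoader, TensorDataset

from composer.core import Evaluator

# Placeholder dataset: 2 samples of 4 features with integer labels.
eval_dataloader = DataLoader(TensorDataset(torch.randn(2, 4), torch.zeros(2, dtype=torch.long)), batch_size=1)

# metric_names omitted (None): the Trainer keeps every eval metric the model defines.
all_metrics_evaluator = Evaluator(label='eval', dataloader=eval_dataloader)

# metric_names=[]: after this commit, the empty list is respected and no metrics are computed.
no_metrics_evaluator = Evaluator(label='eval', dataloader=eval_dataloader, metric_names=[])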
13 changes: 6 additions & 7 deletions composer/trainer/trainer.py
@@ -106,14 +106,13 @@ def _get_default_scheduler_frequency(schedulers: Optional[Union[Scheduler, Seque
 def _filter_metrics(metrics: Dict[str, Metric], metric_names: Optional[List[str]]) -> Dict[str, Metric]:
     """Filter the metrics based on the given metric_names as regex strings (e.g. 'Accuracy', 'f1' for 'BinaryF1Score', 'Top-.' for 'Top-1 Accuracy' and 'Top-2 Accuracy', etc). If no metric_names are provided, all metrics will be returned."""
     metrics = deepcopy(metrics)
-    if not metric_names:
+    if metric_names is None:
         return metrics
-    else:
-        filtered_metrics = {}
-        for name, metric in metrics.items():
-            if any(re.match(f'.*{metric_name}.*', name, re.IGNORECASE) for metric_name in metric_names):
-                filtered_metrics[name] = metric
-        return filtered_metrics
+    filtered_metrics = {}
+    for name, metric in metrics.items():
+        if any(re.match(f'.*{metric_name}.*', name, re.IGNORECASE) for metric_name in metric_names):
+            filtered_metrics[name] = metric
+    return filtered_metrics
 
 
 def _validate_precision(precision: Precision, device: Device):
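For reference, a standalone sketch of the new filtering semantics: the _filter_metrics body is copied from the diff above, but the metric values are placeholder strings rather than torchmetrics Metric objects:

import re
from copy import deepcopy
from typing import Dict, List, Optional


def _filter_metrics(metrics: Dict[str, object], metric_names: Optional[List[str]]) -> Dict[str, object]:
    metrics = deepcopy(metrics)
    if metric_names is None:
        # None means "no filter requested": return every metric.
        return metrics
    filtered_metrics = {}
    for name, metric in metrics.items():
        # Each entry in metric_names is a case-insensitive regex fragment.
        if any(re.match(f'.*{metric_name}.*', name, re.IGNORECASE) for metric_name in metric_names):
            filtered_metrics[name] = metric
    return filtered_metrics


metrics = {'MulticlassAccuracy': 'acc-metric', 'CrossEntropy': 'ce-metric'}
assert _filter_metrics(metrics, None) == metrics   # None -> all metrics
assert _filter_metrics(metrics, []) == {}          # [] -> no metrics (previously returned all)
assert list(_filter_metrics(metrics, ['accuracy'])) == ['MulticlassAccuracy']

Before this commit, the `if not metric_names` check made an empty list indistinguishable from None, so an Evaluator constructed with metric_names=[] silently ran every metric.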
40 changes: 40 additions & 0 deletions tests/trainer/test_trainer.py
@@ -167,6 +167,46 @@ def test_compile_unsupported_torch_version_exception(self, caplog, model: Compos
                 auto_log_hparams=True,
                 compile_config={})
 
+    def test_eval_metrics(self):
+        model = SimpleModel()
+        train_dataloader = DataLoader(RandomClassificationDataset(size=1), batch_size=1)
+        all_metrics = model.get_metrics(is_train=False)
+
+        # Test default eval metrics
+        trainer = Trainer(
+            model=model,
+            train_dataloader=train_dataloader,
+            eval_dataloader=Evaluator(label='eval',
+                                      dataloader=DataLoader(RandomClassificationDataset(size=1), batch_size=1)),
+        )
+
+        assert trainer.state.eval_metrics['eval'] == all_metrics
+
+        # Test empty eval metrics
+        trainer = Trainer(
+            model=model,
+            train_dataloader=train_dataloader,
+            eval_dataloader=Evaluator(label='eval',
+                                      dataloader=DataLoader(RandomClassificationDataset(size=1), batch_size=1),
+                                      metric_names=[]),
+        )
+
+        assert trainer.state.eval_metrics['eval'] == {}
+
+        # Test selected eval metrics
+        single_metric = next(iter(all_metrics))
+        trainer = Trainer(
+            model=model,
+            train_dataloader=train_dataloader,
+            eval_dataloader=Evaluator(label='eval',
+                                      dataloader=DataLoader(RandomClassificationDataset(size=1), batch_size=1),
+                                      metric_names=[single_metric]),
+        )
+
+        eval_metric_names = trainer.state.eval_metrics['eval'].keys()
+        assert len(eval_metric_names) == 1
+        assert next(iter(eval_metric_names)) == single_metric
+
 
 def _assert_optimizer_is_on_device(optimizer: torch.optim.Optimizer):
     for state in optimizer.state.values():
