Commit

Remove eval batch split (#1576)
* remove eval batch split

* refactor
mvpatel2000 authored Sep 30, 2022
1 parent 20d073d commit 1b7ffce
Showing 1 changed file with 19 additions and 41 deletions.
60 changes: 19 additions & 41 deletions composer/trainer/trainer.py
@@ -1721,48 +1721,26 @@ def _eval_train_metrics(self, device_batch):
         assert self._train_data_spec is not None, 'The train data spec should be set on __init__ or fit()'
         assert self.state.train_metrics is not None, 'The train metrics should be set on __init__ or fit()'
 
-        with torch.no_grad(), model_eval_mode(self.state.model):
-            # Retry until we successfully complete evaluation
-            while True:
-                found_cuda_oom = 0  # int since bool BOR not supported on all torch.distributed backends
-                try:
-                    for eval_microbatch in self._train_data_spec.split_batch(device_batch, self.state.eval_batch_split):
-                        with get_precision_context(self.state.precision):
-                            if hasattr(self._original_model, 'validate'):  # backwards compatibility check
-                                warnings.warn(
-                                    'Using validate() is no longer supported and will be removed in a future version. Please use eval_forward() instead.'
-                                )
-                                assert isinstance(self._original_model.validate, Callable)
-                                eval_outputs, target = self._original_model.validate(eval_microbatch)
-
-                                for _, metric in self.state.train_metrics.items():
-                                    metric.update(eval_outputs, target)
-                            else:
-                                eval_outputs = self._original_model.eval_forward(eval_microbatch, self.state.outputs)
-                                for _, metric in self.state.train_metrics.items():
-                                    self._original_model.update_metric(
-                                        eval_microbatch,
-                                        eval_outputs,
-                                        metric,
-                                    )
+        with torch.no_grad(),\
+                model_eval_mode(self.state.model),\
+                get_precision_context(self.state.precision):
+            if hasattr(self._original_model, 'validate'):  # backwards compatibility check
+                warnings.warn(
+                    'Using validate() is no longer supported and will be removed in a future version. Please use eval_forward() instead.'
+                )
+                assert isinstance(self._original_model.validate, Callable)
+                eval_outputs, target = self._original_model.validate(device_batch)
 
-                except RuntimeError as e:
-                    if self.state.auto_grad_accum and _is_cuda_oom(e):
-                        log.debug((f"Rank {dist.get_global_rank()} OOM'd."))
-                        found_cuda_oom = 1
-                    else:
-                        raise
-                if self.state.auto_grad_accum:
-                    # Propagate across all ranks if any rank hit CUDA OOM
-                    found_cuda_oom = self._device.tensor_to_device(torch.tensor([found_cuda_oom], dtype=torch.uint8))
-                    dist.all_reduce(found_cuda_oom, reduce_operation='MAX')
-                    if found_cuda_oom.item() == 1:
-                        device_batch_size = self._train_data_spec.get_num_samples_in_batch(device_batch)
-                        _adjust_eval_batch_split(self.state, device_batch_size)
-                        # Skip return and rerun after handling oom
-                        continue
-                # Return if we've successfully completed eval without OOMing.
-                return
+                for _, metric in self.state.train_metrics.items():
+                    metric.update(eval_outputs, target)
+            else:
+                eval_outputs = self._original_model.eval_forward(device_batch, self.state.outputs)
+                for _, metric in self.state.train_metrics.items():
+                    self._original_model.update_metric(
+                        device_batch,
+                        eval_outputs,
+                        metric,
+                    )
 
     def _run_evaluators(self, event: Event):
         """Runs evaluators periodically during training."""