
Object Store Logger Race Condition + EMA Fix #1552

Merged
23 changes: 10 additions & 13 deletions composer/callbacks/checkpoint_saver.py
@@ -55,20 +55,13 @@ def checkpoint_periodically(interval: Union[str, int, Time]) -> Callable[[State,
raise NotImplementedError(
f'Unknown checkpointing interval: {interval.unit}. Must be TimeUnit.EPOCH or TimeUnit.BATCH.')

last_checkpoint_batch: Optional[Time] = None

def checkpoint_save_interval(state: State, event: Event):
nonlocal last_checkpoint_batch
elapsed_duration = state.get_elapsed_duration()
assert elapsed_duration is not None, 'elapsed_duration is set on the BATCH_CHECKPOINT and EPOCH_CHECKPOINT'

# Always checkpoint at end of training
if elapsed_duration >= 1.0:
# if doing batch-wise checkpointing, and we saved a checkpoint at the batch_checkpoint event
# right before the epoch_checkpoint event, do not save another checkpoint at the epoch_checkpoint
# event if the batch count didn't increase.
if state.timestamp.batch != last_checkpoint_batch:
last_checkpoint_batch = state.timestamp.batch
return True
return True

if save_event == Event.EPOCH_CHECKPOINT:
count = state.timestamp.epoch
@@ -78,7 +71,6 @@ def checkpoint_save_interval(state: State, event: Event):
raise RuntimeError(f'Invalid save_event: {save_event}')

if event == save_event and int(count) % int(interval) == 0:
last_checkpoint_batch = state.timestamp.batch
return True

return False
@@ -257,7 +249,7 @@ class CheckpointSaver(Callback): # noqa: D101
event.

weights_only (bool): If ``True``, save only the model weights instead of the entire training state.
This parmeter must be ``False`` when using DeepSpeed. Default: ``False``.
This parameter must be ``False`` when using DeepSpeed. Default: ``False``.


num_checkpoints_to_keep (int, optional): The number of checkpoints to keep locally. The oldest checkpoints
@@ -303,6 +295,7 @@ def __init__(
if not callable(checkpoint_save_interval):
checkpoint_save_interval = checkpoint_periodically(checkpoint_save_interval)
self.checkpoint_save_interval = checkpoint_save_interval
self.last_checkpoint_batch: Optional[Time] = None

self.checkpoint_save_path = checkpoint_save_path

@@ -336,15 +329,17 @@ def fit_start(self, state: State, logger: Logger) -> None:
raise NotImplementedError('weights_only=True is not supported when using DeepSpeed.')

def batch_checkpoint(self, state: State, logger: Logger):
if self.checkpoint_save_interval(state, Event.BATCH_CHECKPOINT):
if self.checkpoint_save_interval(
state, Event.BATCH_CHECKPOINT) and self.last_checkpoint_batch != state.timestamp.batch:
self._save_checkpoint(
state,
logger,
self.get_log_level(state, default=LogLevel.BATCH),
)

def epoch_checkpoint(self, state: State, logger: Logger):
if self.checkpoint_save_interval(state, Event.EPOCH_CHECKPOINT):
if self.checkpoint_save_interval(
state, Event.EPOCH_CHECKPOINT) and self.last_checkpoint_batch != state.timestamp.batch:
self._save_checkpoint(
state,
logger,
@@ -363,6 +358,8 @@ def get_state_dict(self, state):
}

def _save_checkpoint(self, state: State, logger: Logger, log_level: LogLevel):
self.last_checkpoint_batch = state.timestamp.batch

is_deepspeed = is_model_deepspeed(state.model)

if is_deepspeed and '{rank}' not in self.checkpoint_filename.filename:
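The hunk above moves the deduplication state out of the `checkpoint_periodically` closure and onto the `CheckpointSaver` instance: both `batch_checkpoint` and `epoch_checkpoint` now skip saving when `self.last_checkpoint_batch` already equals the current batch count, and `_save_checkpoint` records the batch it saved at. Below is a minimal sketch of that guard using a simplified stand-in class rather than Composer's `State`/`Event` types; the names are illustrative only.

```python
from typing import Optional


class TinySaver:
    """Simplified stand-in for CheckpointSaver's dedup guard (illustration only)."""

    def __init__(self) -> None:
        # Batch count of the most recent save; None until the first save.
        self.last_checkpoint_batch: Optional[int] = None

    def maybe_save(self, batch: int, interval_reached: bool) -> bool:
        # Skip the save if a checkpoint was already written at this exact
        # batch count, e.g. a BATCH_CHECKPOINT immediately followed by an
        # EPOCH_CHECKPOINT before the batch counter advances.
        if interval_reached and self.last_checkpoint_batch != batch:
            self.last_checkpoint_batch = batch
            return True
        return False


saver = TinySaver()
assert saver.maybe_save(batch=100, interval_reached=True)      # batch-interval save
assert not saver.maybe_save(batch=100, interval_reached=True)  # same batch: skipped
assert saver.maybe_save(batch=200, interval_reached=True)      # counter advanced: saves
```

Keeping the guard on the instance means the end-of-training save (`elapsed_duration >= 1.0`) and the interval-based saves share a single source of truth for "did we already save at this batch".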
32 changes: 29 additions & 3 deletions composer/loggers/object_store_logger.py
@@ -375,11 +375,10 @@ def _enqueue_uploads(self):
break
else:
# If any worker died, then it's impossible to recover since the file was already popped off of the queue,
# so break.Some files may not be uploaded.
# so break. Some files may not be uploaded.
break

# Yield the lock, so it can be acquired by `self.log_file_artifact`
time.sleep(0.2)
time.sleep(0.2) # Yield lock for `self.log_file_artifact`

def get_file_artifact(
self,
@@ -394,6 +393,33 @@ def get_file_artifact(
overwrite=overwrite,
progress_bar=progress_bar)

def fit_end(self, state: State, logger: Logger):
self.wait_for_workers()

def eval_end(self, state: State, logger: Logger):
self.wait_for_workers()

def predict_end(self, state: State, logger: Logger):
self.wait_for_workers()

def wait_for_workers(self):
"""Wait for all tasks to be completed.

This is called after fit/eval/predict. If we don't wait, then a worker might not schedule
an upload before the interpreter is shut down and garbage collection begins. While
post_close logic ensures existing uploads are completed, trying to schedule new uploads
during this time will error.
"""
# Verify enqueue thread has processed all tasks
while True:
with self._object_lock:
if len(self._logged_objects) == 0:
break
time.sleep(0.2) # Yield lock for enqueue thread
# Verify all tasks have been completed
while not self._file_upload_queue.empty():
time.sleep(0.2)

def post_close(self):
# Shutdown logic:
# 1. Signal to the enqueue thread that all uploads are enqueued. Specifically.
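The new `fit_end`/`eval_end`/`predict_end` hooks all funnel into `wait_for_workers`, which drains pending uploads before the interpreter can begin shutting down. Below is a minimal sketch of that drain-before-shutdown pattern with a toy uploader built from a single background thread and a `queue.Queue`; the class and method names are illustrative and are not the `ObjectStoreLogger` API.

```python
import queue
import threading
import time
from typing import List


class ToyUploader:
    """Toy stand-in for an object-store logger (illustration only)."""

    def __init__(self) -> None:
        self._staged: List[str] = []           # files logged but not yet enqueued
        self._lock = threading.Lock()          # guards self._staged
        self._upload_queue: queue.Queue = queue.Queue()
        self._worker = threading.Thread(target=self._upload_loop, daemon=True)
        self._worker.start()

    def log_file(self, path: str) -> None:
        with self._lock:
            self._staged.append(path)

    def _upload_loop(self) -> None:
        while True:
            # Move staged files onto the upload queue under the lock...
            with self._lock:
                while self._staged:
                    self._upload_queue.put(self._staged.pop(0))
            # ...then "upload" whatever is queued.
            try:
                path = self._upload_queue.get(timeout=0.1)
            except queue.Empty:
                continue
            time.sleep(0.05)  # pretend to upload `path`
            self._upload_queue.task_done()

    def wait_for_workers(self) -> None:
        # Phase 1: poll until the staging list is drained, sleeping between
        # checks so the worker thread can acquire the lock.
        while True:
            with self._lock:
                if not self._staged:
                    break
            time.sleep(0.2)
        # Phase 2: block until every enqueued upload has finished.
        self._upload_queue.join()


uploader = ToyUploader()
for i in range(3):
    uploader.log_file(f"checkpoint_{i}.pt")
uploader.wait_for_workers()  # returns only once all three "uploads" are done
print("safe to shut down")
```

The 0.2 s sleeps in the real logger serve the same purpose as the one in the sketch: releasing the lock long enough for the enqueue thread to make progress between polls.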