Commit

Better try - consolidate logic and ensure shutdown occurs on error
srowen committed Aug 14, 2024
1 parent 68508ae commit 2b2c0fb
Showing 1 changed file with 26 additions and 24 deletions.
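For context, a minimal, self-contained sketch of the pattern this commit consolidates. This is not the library's code: `Uploader` and `upload` are hypothetical stand-ins. Worker threads record a failure by setting a shared threading.Event, the writer checks that event before queuing more work, and shutdown always occurs, cancelling pending futures when an error was recorded (cancel_futures requires Python 3.9 or newer).

import threading
from concurrent.futures import Future, ThreadPoolExecutor


def upload(path: str) -> None:
    """Stand-in for a real shard upload; may raise on failure."""
    print(f'uploading {path}')


class Uploader:
    """Toy writer mirroring the event-based error handling in this commit."""

    def __init__(self) -> None:
        self.executor = ThreadPoolExecutor(max_workers=4)
        self.event = threading.Event()

    def _on_done(self, future: Future) -> None:
        # Runs on an executor thread; raising here would not reach the caller,
        # so record the failure and let the submitting thread act on it.
        if future.exception() is not None:
            self.event.set()

    def submit(self, path: str) -> None:
        if self.event.is_set():
            # A previous upload failed: cancel pending work and surface the error.
            self.executor.shutdown(wait=False, cancel_futures=True)
            raise Exception('One of the threads failed.')
        self.executor.submit(upload, path).add_done_callback(self._on_done)

    def finish(self) -> None:
        if self.event.is_set():
            # Shutdown still occurs on error, just without waiting for pending work.
            self.executor.shutdown(wait=False, cancel_futures=True)
        else:
            self.executor.shutdown(wait=True)
        # Final check, in case an upload failed while the queue drained.
        if self.event.is_set():
            raise Exception('One of the threads failed.')

With this shape, a worker failure surfaces on the next submit() or at finish() instead of being lost inside the pool thread.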
50 changes: 26 additions & 24 deletions streaming/base/format/base/writer.py
@@ -245,15 +245,6 @@ def flush_shard(self) -> None:
"""Flush cached samples to storage, creating a new shard."""
raise NotImplementedError

def _check_event_set(self) -> None:
"""Check if the event is set and raise an exception."""
if self.event.is_set():
# Shutdown the executor and cancel all the pending futures due to exception in one of
# the threads.
self.cancel_future_jobs()
raise Exception('One of the threads failed. Check other traceback for more ' +
'details.')

def write(self, sample: dict[str, Any]) -> None:
"""Write a sample.
@@ -262,7 +253,12 @@ def write(self, sample: dict[str, Any]) -> None:
Args:
sample (Dict[str, Any]): Sample dict.
"""
self._check_event_set()
if self.event.is_set():
# Shut down the executor and cancel all the pending futures due to an exception in
# one of the threads.
self.cancel_future_jobs()
raise Exception('One of the threads failed. Check other traceback for more ' +
'details.')
# Execute the task if there is no exception in any of the async threads.
new_sample = self.encode_sample(sample)
new_sample_size = len(new_sample) + self.extra_bytes_per_sample
@@ -276,7 +272,6 @@ def _write_index(self) -> None:
"""Write the index, having written all the shards."""
if self.new_samples:
raise RuntimeError('Internal error: not all samples have been written.')
self._check_event_set()
basename = get_index_basename()
filename = os.path.join(self.local, basename)
obj = {
@@ -298,21 +293,30 @@ def _write_index(self) -> None:

def finish(self) -> None:
"""Finish writing samples."""
if self.new_samples:
self.flush_shard()
self._reset_cache()
self._write_index()
logger.debug(f'Waiting for all shard files to get uploaded to {self.remote}')
self.executor.shutdown(wait=True)
if self.event.is_set():
# If an error occurred, cancel any outstanding uploads
self.cancel_future_jobs()
else:
# Otherwise finish any remaining uploads and write index, and wait for completion
if self.new_samples:
self.flush_shard()
self._reset_cache()
self._write_index()
logger.debug(f'Waiting for all shard files to get uploaded to {self.remote}')
self.executor.shutdown(wait=True)

if self.remote and not self.keep_local:
shutil.rmtree(self.local, ignore_errors=True)

# Final check, in case an error occurred in an upload during shutdown
if self.event.is_set():
raise Exception('One of the threads failed. Check other traceback for more ' +
'details.')

def cancel_future_jobs(self) -> None:
"""Shutting down the executor and cancel all the pending jobs."""
# Beginning python v3.9, ThreadPoolExecutor.shutdown() has a new parameter `cancel_futures`
self.executor.shutdown(wait=False, cancel_futures=True)
if self.remote and not self.keep_local:
shutil.rmtree(self.local, ignore_errors=True)

def exception_callback(self, future: Future) -> None:
"""Raise an exception to the caller if exception generated by one of an async thread.
@@ -330,8 +334,9 @@ def exception_callback(self, future: Future) -> None:
if exception:
# Set the event to let other pool threads know about the exception
self.event.set()
# re-raise the same exception
raise exception
# Log the exception; raising here does not propagate to the caller, because this
# is called from a Future thread and would just be logged as an unexpected error
logger.error(f'Exception in writer thread: {exception}')

def __enter__(self) -> Self:
"""Enter context manager.
@@ -350,7 +355,6 @@ def __exit__(self, exc_type: Optional[type[BaseException]], exc: Optional[BaseEx
exc (BaseException, optional): Exc.
traceback (TracebackType, optional): Traceback.
"""
self._check_event_set()
self.finish()


@@ -418,7 +422,6 @@ def encode_joint_shard(self) -> bytes:
raise NotImplementedError

def flush_shard(self) -> None:
self._check_event_set()

raw_data_basename, zip_data_basename = self._name_next_shard()
raw_data = self.encode_joint_shard()
@@ -501,7 +504,6 @@ def encode_split_shard(self) -> tuple[bytes, bytes]:
raise NotImplementedError

def flush_shard(self) -> None:
self._check_event_set()

raw_data_basename, zip_data_basename = self._name_next_shard()
raw_meta_basename, zip_meta_basename = self._name_next_shard('meta')
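The change to exception_callback, logging instead of re-raising, reflects how concurrent.futures treats done-callbacks: an exception raised inside a callback is logged by the executor and ignored, so it never propagates to the code that submitted the job. A small demonstration, hypothetical and not part of the library:

import logging
from concurrent.futures import Future, ThreadPoolExecutor

logging.basicConfig(level=logging.ERROR)
logger = logging.getLogger(__name__)


def task() -> None:
    # Simulate a failed upload running in a worker thread.
    raise ValueError('upload failed')


def callback(future: Future) -> None:
    exception = future.exception()
    if exception:
        # Raising `exception` here would not reach the main thread; the executor
        # would only log it as 'exception calling callback'.
        logger.error(f'Exception in worker thread: {exception}')


with ThreadPoolExecutor(max_workers=1) as executor:
    future = executor.submit(task)
    future.add_done_callback(callback)
# The main thread continues normally; the failure is visible only through the
# log line, through future.exception(), or through a shared event as the writer uses.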
