Skip to content

Commit

Permalink
Throw exception when event.is_set() in all places checked by Writers
Browse files Browse the repository at this point in the history
  • Loading branch information
srowen committed Aug 14, 2024
1 parent 4465be7 commit 68508ae
Showing 1 changed file with 14 additions and 26 deletions.
40 changes: 14 additions & 26 deletions streaming/base/format/base/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,15 @@ def flush_shard(self) -> None:
"""Flush cached samples to storage, creating a new shard."""
raise NotImplementedError

def _check_event_set(self) -> None:
"""Check if the event is set and raise an exception."""
if self.event.is_set():
# Shutdown the executor and cancel all the pending futures due to exception in one of
# the threads.
self.cancel_future_jobs()
raise Exception('One of the threads failed. Check other traceback for more ' +
'details.')

def write(self, sample: dict[str, Any]) -> None:
"""Write a sample.
Expand All @@ -253,12 +262,7 @@ def write(self, sample: dict[str, Any]) -> None:
Args:
sample (Dict[str, Any]): Sample dict.
"""
if self.event.is_set():
# Shutdown the executor and cancel all the pending futures due to exception in one of
# the threads.
self.cancel_future_jobs()
raise Exception('One of the threads failed. Check other traceback for more ' +
'details.')
self._check_event_set()
# Execute the task if there is no exception in any of the async threads.
new_sample = self.encode_sample(sample)
new_sample_size = len(new_sample) + self.extra_bytes_per_sample
Expand All @@ -272,11 +276,7 @@ def _write_index(self) -> None:
"""Write the index, having written all the shards."""
if self.new_samples:
raise RuntimeError('Internal error: not all samples have been written.')
if self.event.is_set():
# Shutdown the executor and cancel all the pending futures due to exception in one of
# the threads.
self.cancel_future_jobs()
return
self._check_event_set()
basename = get_index_basename()
filename = os.path.join(self.local, basename)
obj = {
Expand Down Expand Up @@ -350,11 +350,7 @@ def __exit__(self, exc_type: Optional[type[BaseException]], exc: Optional[BaseEx
exc (BaseException, optional): Exc.
traceback (TracebackType, optional): Traceback.
"""
if self.event.is_set():
# Shutdown the executor and cancel all the pending futures due to exception in one of
# the threads.
self.cancel_future_jobs()
return
self._check_event_set()
self.finish()


Expand Down Expand Up @@ -422,11 +418,7 @@ def encode_joint_shard(self) -> bytes:
raise NotImplementedError

def flush_shard(self) -> None:
if self.event.is_set():
# Shutdown the executor and cancel all the pending futures due to exception in one of
# the threads.
self.cancel_future_jobs()
return
self._check_event_set()

raw_data_basename, zip_data_basename = self._name_next_shard()
raw_data = self.encode_joint_shard()
Expand Down Expand Up @@ -509,11 +501,7 @@ def encode_split_shard(self) -> tuple[bytes, bytes]:
raise NotImplementedError

def flush_shard(self) -> None:
if self.event.is_set():
# Shutdown the executor and cancel all the pending futures due to exception in one of
# the threads.
self.cancel_future_jobs()
return
self._check_event_set()

raw_data_basename, zip_data_basename = self._name_next_shard()
raw_meta_basename, zip_meta_basename = self._name_next_shard('meta')
Expand Down

0 comments on commit 68508ae

Please sign in to comment.