Skip to content

Commit

Permalink
Add collecting error handler.
Browse files Browse the repository at this point in the history
  • Loading branch information
robertwb committed Jul 11, 2024
1 parent c4be92f commit 34e28f3
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 8 deletions.
30 changes: 29 additions & 1 deletion sdks/python/apache_beam/transforms/error_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def close(self):
self._closed = True

def output(self):
"""Returns
"""Returns result of applying the error consumer to the error pcollections.
"""
if not self._closed:
raise RuntimeError(
Expand All @@ -96,3 +96,31 @@ def verify_closed(self):
if not self._closed:
raise RuntimeError(
"Unclosed error handler initialized at %s" % self._creation_traceback)


class _IdentityPTransform(transforms.PTransform):
def expand(self, pcoll):
return pcoll


class CollectingErrorHandler(ErrorHandler):
"""An ErrorHandler that simply collects all errors for further processing.
This ErrorHandler requires the set of errors be retrieved via `output()`
and consumed (or explicitly discarded).
"""
def __init__(self):
super().__init__(_IdentityPTransform())
self._creation_traceback = traceback.format_stack()[-2]
self._output_accessed = False

def output(self):
self._output_accessed = True
return super().output()

def verify_closed(self):
if not self._output_accessed:
raise RuntimeError(
"CollectingErrorHandler requires the output to be retrieved. "
"Initialized at %s" % self._creation_traceback)
return super().verify_closed()
37 changes: 30 additions & 7 deletions sdks/python/apache_beam/transforms/error_handling_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,13 @@ def raise_on_everything(element):
return good


def exception_throwing_map(x, limit):
if len(x) > limit:
raise ValueError(x)
else:
return x.title()


class ErrorHandlingTest(unittest.TestCase):
def test_error_handling(self):
with beam.Pipeline() as p:
Expand All @@ -68,18 +75,12 @@ def test_error_handling(self):
assert_that(error_pcoll, equal_to(['error: cccc']), label='CheckBad')

def test_error_handling_pardo(self):
def my_map(x, limit):
if len(x) > limit:
raise ValueError(x)
else:
return x.title()

with beam.Pipeline() as p:
pcoll = p | beam.Create(['a', 'bb', 'cccc'])
with error_handling.ErrorHandler(
beam.Map(lambda x: "error: %s" % x[0])) as error_handler:
result = pcoll | beam.Map(
my_map, limit=3).with_error_handler(error_handler)
exception_throwing_map, limit=3).with_error_handler(error_handler)
error_pcoll = error_handler.output()

assert_that(result, equal_to(['A', 'Bb']), label='CheckGood')
Expand All @@ -92,3 +93,25 @@ def test_error_on_unclosed_error_handler(self):
# Use this outside of a context to allow it to remain unclosed.
error_handler = error_handling.ErrorHandler(beam.Map(lambda x: x))
_ = pcoll | PTransformWithErrors(3).with_error_handler(error_handler)

def test_collecting_error_handler(self):
with beam.Pipeline() as p:
pcoll = p | beam.Create(['a', 'bb', 'cccc'])
with error_handling.CollectingErrorHandler() as error_handler:
result = pcoll | beam.Map(
exception_throwing_map, limit=3).with_error_handler(error_handler)
error_pcoll = error_handler.output() | beam.Map(lambda x: x[0])

assert_that(result, equal_to(['A', 'Bb']), label='CheckGood')
assert_that(error_pcoll, equal_to(['cccc']), label='CheckBad')

def test_error_on_collecting_error_handler_without_output_retrieval(self):
with self.assertRaisesRegex(
RuntimeError,
r'.*CollectingErrorHandler requires the output to be retrieved.*'):
with beam.Pipeline() as p:
pcoll = p | beam.Create(['a', 'bb', 'cccc'])
with error_handling.CollectingErrorHandler() as error_handler:
result = pcoll | beam.Map(
exception_throwing_map,
limit=3).with_error_handler(error_handler)

0 comments on commit 34e28f3

Please sign in to comment.