Skip to content

Commit

Permalink
Merge pull request #31856 Add ErrorHandler DLQ API to Python.
Browse files Browse the repository at this point in the history
  • Loading branch information
robertwb authored Aug 21, 2024
2 parents 3fb4fd0 + 36e5eff commit b3a874f
Show file tree
Hide file tree
Showing 4 changed files with 303 additions and 1 deletion.
7 changes: 7 additions & 0 deletions sdks/python/apache_beam/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,7 @@ def __init__(
self.contains_external_transforms = False

self._display_data = display_data or {}
self._error_handlers = []

def display_data(self):
# type: () -> Dict[str, Any]
Expand All @@ -258,6 +259,9 @@ def allow_unsafe_triggers(self):
# type: () -> bool
return self._options.view_as(TypeOptions).allow_unsafe_triggers

def _register_error_handler(self, error_handler):
  """Records an error handler so that `run()` can verify it was closed.

  Error handlers call this each time they are given an error PCollection,
  so the same handler may be passed in many times. It is stored only once
  so that `run()` checks each handler a single time.

  Args:
    error_handler: the handler to track; `run()` calls its
      `verify_closed()` method before the pipeline is executed.
  """
  # Compare by identity so distinct handlers are always tracked separately,
  # even if a handler class chooses to define __eq__.
  if not any(known is error_handler for known in self._error_handlers):
    self._error_handlers.append(error_handler)

def _current_transform(self):
# type: () -> AppliedPTransform

Expand Down Expand Up @@ -531,6 +535,9 @@ def run(self, test_runner_api='AUTO'):

"""Runs the pipeline. Returns whatever our runner returns after running."""

for error_handler in self._error_handlers:
error_handler.verify_closed()

# Records whether this pipeline contains any cross-language transforms.
self.contains_external_transforms = (
ExternalTransformFinder.contains_external_transforms(self))
Expand Down
23 changes: 22 additions & 1 deletion sdks/python/apache_beam/transforms/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1573,6 +1573,7 @@ def with_exception_handling(
threshold=1,
threshold_windowing=None,
timeout=None,
error_handler=None,
on_failure_callback: typing.Optional[typing.Callable[
[Exception, typing.Any], None]] = None):
"""Automatically provides a dead letter output for skipping bad records.
Expand Down Expand Up @@ -1622,6 +1623,8 @@ def with_exception_handling(
defaults to the windowing of the input.
timeout: If the element has not finished processing in timeout seconds,
raise a TimeoutError. Defaults to None, meaning no time limit.
error_handler: An ErrorHandler that should be used to consume the bad
records, rather than returning the good and bad records as a tuple.
on_failure_callback: If an element fails or times out,
on_failure_callback will be invoked. It will receive the exception
and the element being processed in as args. In case of a timeout,
Expand All @@ -1642,8 +1645,20 @@ def with_exception_handling(
threshold,
threshold_windowing,
timeout,
error_handler,
on_failure_callback)

def with_error_handler(self, error_handler, **exception_handling_kwargs):
  """Routes failing records to `error_handler`.

  Shorthand for `with_exception_handling(error_handler=error_handler, ...)`,
  provided so this transform fits the general ErrorHandler conventions.

  Args:
    error_handler: the ErrorHandler that should consume bad records, or
      None to leave this transform unchanged.
    **exception_handling_kwargs: passed through to
      `with_exception_handling`; not used when `error_handler` is None.

  Returns:
    This transform itself when `error_handler` is None, otherwise whatever
    `with_exception_handling` returns.
  """
  if error_handler is not None:
    return self.with_exception_handling(
        error_handler=error_handler, **exception_handling_kwargs)
  # No handler supplied: error handling stays disabled.
  return self

def default_type_hints(self):
return self.fn.get_type_hints()

Expand Down Expand Up @@ -2242,6 +2257,7 @@ def __init__(
threshold,
threshold_windowing,
timeout,
error_handler,
on_failure_callback):
if partial and use_subprocess:
raise ValueError('partial and use_subprocess are mutually incompatible.')
Expand All @@ -2256,6 +2272,7 @@ def __init__(
self._threshold = threshold
self._threshold_windowing = threshold_windowing
self._timeout = timeout
self._error_handler = error_handler
self._on_failure_callback = on_failure_callback

def expand(self, pcoll):
Expand Down Expand Up @@ -2306,7 +2323,11 @@ def check_threshold(bad, total, threshold, window=DoFn.WindowParam):
_ = bad_count_pcoll | Map(
check_threshold, input_count_view, self._threshold)

return result
if self._error_handler:
self._error_handler.add_error_pcollection(result[self._dead_letter_tag])
return result[self._main_tag]
else:
return result


class _ExceptionHandlingWrapperDoFn(DoFn):
Expand Down
126 changes: 126 additions & 0 deletions sdks/python/apache_beam/transforms/error_handling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Utilities for gracefully handling errors and excluding bad elements."""

import traceback

from apache_beam import transforms


class ErrorHandler:
  """ErrorHandlers are used to skip and otherwise process bad records.

  Error handlers allow one to implement the "dead letter queue" pattern in
  a fluent manner, disaggregating the error processing specification from
  the main processing chain.

  This is typically used as follows::

    with error_handling.ErrorHandler(WriteToSomewhere(...)) as error_handler:
      result = pcoll | SomeTransform().with_error_handler(error_handler)

  in which case errors encountered by `SomeTransform()` in processing pcoll
  will be written by the PTransform `WriteToSomewhere(...)` and excluded from
  `result` rather than failing the pipeline.

  To implement `with_error_handler` on a PTransform, one caches the provided
  error handler for use in `expand`. During `expand()` one can invoke
  `error_handler.add_error_pcollection(...)` any number of times with
  PCollections containing error records to be processed by the given error
  handler, or (if applicable) simply invoke `with_error_handler(...)` on any
  subtransforms.

  The `with_error_handler` should accept `None` to indicate that error
  handling is not enabled (and make implementation-by-forwarding-error-handlers
  easier). In this case, any non-recoverable errors should fail the pipeline
  (e.g. propagate exceptions in `process` methods) rather than silently
  ignore errors.
  """
  def __init__(self, consumer):
    """Creates an open handler.

    Args:
      consumer: the PTransform applied to the flattened error PCollections
        when the handler is closed.
    """
    self._consumer = consumer
    # format_stack()[-1] is this __init__ frame; [-2] is the caller that
    # created the handler, which is what error messages should point at.
    self._creation_traceback = traceback.format_stack()[-2]
    self._error_pcolls = []
    self._closed = False

  def __enter__(self):
    """Starts a fresh scope: forgets prior error PCollections and reopens."""
    self._error_pcolls = []
    self._closed = False
    return self

  def __exit__(self, *exec_info):
    """Closes the handler unless the with-block is exiting via an exception.

    Returns None, so an in-flight exception always propagates.
    """
    if exec_info[0] is None:
      self.close()

  def close(self):
    """Indicates all error-producing operations have reported any errors.

    Invokes the provided error consuming PTransform on any provided error
    PCollections.
    """
    self._output = (
        tuple(self._error_pcolls) | transforms.Flatten() | self._consumer)
    self._closed = True

  def output(self):
    """Returns result of applying the error consumer to the error pcollections.

    Raises:
      RuntimeError: if the handler has not been closed yet.
    """
    if not self._closed:
      raise RuntimeError(
          "Cannot access the output of an error handler "
          "until it has been closed.")
    return self._output

  def add_error_pcollection(self, pcoll):
    """Called by a class implementing error handling on the error records.

    Args:
      pcoll: a PCollection of error records; its pipeline is told about this
        handler so that it can verify the handler is closed before running.

    Raises:
      RuntimeError: if the handler has already been closed. The consumer has
        already been applied at that point, so the records in `pcoll` would
        otherwise be dropped without any indication.
    """
    if self._closed:
      raise RuntimeError(
          "Cannot add an error PCollection to an error handler "
          "after it has been closed.")
    pcoll.pipeline._register_error_handler(self)
    self._error_pcolls.append(pcoll)

  def verify_closed(self):
    """Called at end of pipeline construction to ensure errors are not ignored.

    Raises:
      RuntimeError: if the handler was never closed.
    """
    if not self._closed:
      raise RuntimeError(
          "Unclosed error handler initialized at %s" % self._creation_traceback)


class _IdentityPTransform(transforms.PTransform):
  """A PTransform whose expansion returns its input unchanged."""
  def expand(self, pcoll):
    # Nothing to apply; hand the input straight back.
    return pcoll


class CollectingErrorHandler(ErrorHandler):
  """An ErrorHandler that simply collects all errors for further processing.

  The errors are passed through an identity transform, so `output()` yields
  the flattened error records themselves. This ErrorHandler requires the set
  of errors be retrieved via `output()` and consumed (or explicitly
  discarded); `verify_closed()` fails otherwise.
  """
  def __init__(self):
    super().__init__(_IdentityPTransform())
    self._output_accessed = False
    # The base constructor captured this __init__ as the creation site;
    # overwrite it with the frame of whoever constructed this handler.
    self._creation_traceback = traceback.format_stack()[-2]

  def output(self):
    """Returns the collected errors, noting that they were requested."""
    # The flag is set before delegating, so it is recorded even when the
    # base class raises because the handler is still open.
    self._output_accessed = True
    collected = super().output()
    return collected

  def verify_closed(self):
    """Checks that the output was retrieved and the handler was closed.

    Raises:
      RuntimeError: if `output()` was never called, or (from the base class)
        if the handler was never closed.
    """
    if self._output_accessed:
      return super().verify_closed()
    raise RuntimeError(
        "CollectingErrorHandler requires the output to be retrieved. "
        "Initialized at %s" % self._creation_traceback)
148 changes: 148 additions & 0 deletions sdks/python/apache_beam/transforms/error_handling_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import logging
import unittest

import apache_beam as beam
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from apache_beam.transforms import error_handling


class PTransformWithErrors(beam.PTransform):
  """Test transform that title-cases short strings and rejects long ones.

  Elements shorter than `limit` are emitted title-cased on the main output.
  All other elements go to the 'bad' tagged output, which is either handed
  to the configured error handler or, when there is none, fed to a Map that
  raises for every element.
  """
  def __init__(self, limit):
    self._error_handler = None
    self._limit = limit

  def with_error_handler(self, error_handler):
    """Stores the handler for use in `expand` and returns this transform."""
    self._error_handler = error_handler
    return self

  def expand(self, pcoll):
    # Copy into a local so the closures below do not capture `self`.
    limit = self._limit

    def process(element):
      if len(element) >= limit:
        return beam.pvalue.TaggedOutput('bad', element)
      return element.title()

    def raise_on_everything(element):
      raise ValueError(element)

    splitter = beam.Map(process).with_outputs('bad', main='good')
    good, bad = pcoll | splitter

    handler = self._error_handler
    if not handler:
      # Will throw an exception if there are any bad elements.
      _ = bad | beam.Map(raise_on_everything)
    else:
      handler.add_error_pcollection(bad)
    return good


def exception_throwing_map(x, limit):
  """Title-cases `x`, or raises if it is longer than `limit`.

  Args:
    x: the string to process.
    limit: the greatest length that is still accepted.

  Returns:
    `x.title()` when `len(x) <= limit`.

  Raises:
    ValueError: carrying `x`, when `len(x) > limit`.
  """
  if len(x) <= limit:
    return x.title()
  raise ValueError(x)


class ErrorHandlingTest(unittest.TestCase):
  """End-to-end tests for ErrorHandler and CollectingErrorHandler."""
  def test_error_handling(self):
    """A custom PTransform can hand its bad records to an ErrorHandler."""
    with beam.Pipeline() as p:
      pcoll = p | beam.Create(['a', 'bb', 'cccc'])
      with error_handling.ErrorHandler(
          beam.Map(lambda x: "error: %s" % x)) as error_handler:
        result = pcoll | PTransformWithErrors(3).with_error_handler(
            error_handler)
      error_pcoll = error_handler.output()

      assert_that(result, equal_to(['A', 'Bb']), label='CheckGood')
      assert_that(error_pcoll, equal_to(['error: cccc']), label='CheckBad')

  def test_error_handling_pardo(self):
    """Exceptions raised inside a Map are routed to the ErrorHandler."""
    with beam.Pipeline() as p:
      pcoll = p | beam.Create(['a', 'bb', 'cccc'])
      # The error records are indexable; entry 0 is the failing element.
      with error_handling.ErrorHandler(
          beam.Map(lambda x: "error: %s" % x[0])) as error_handler:
        result = pcoll | beam.Map(
            exception_throwing_map, limit=3).with_error_handler(error_handler)
      error_pcoll = error_handler.output()

      assert_that(result, equal_to(['A', 'Bb']), label='CheckGood')
      assert_that(error_pcoll, equal_to(['error: cccc']), label='CheckBad')

  def test_error_handling_pardo_with_exception_handling_kwargs(self):
    """Extra kwargs such as on_failure_callback are forwarded and honored."""
    marker = '_test_error_handling_pardo_with_exception_handling_kwargs_val'

    def side_effect(*args):
      beam._test_error_handling_pardo_with_exception_handling_kwargs_val = True

    def check_side_effect():
      return getattr(beam, marker, False)

    def clear_side_effect():
      if hasattr(beam, marker):
        delattr(beam, marker)

    # The marker lives on the apache_beam module itself; remove it afterwards
    # so that rerunning this test in the same process starts from a clean
    # state instead of failing the assertFalse below.
    self.addCleanup(clear_side_effect)

    self.assertFalse(check_side_effect())

    with beam.Pipeline() as p:
      pcoll = p | beam.Create(['a', 'bb', 'cccc'])
      with error_handling.ErrorHandler(
          beam.Map(lambda x: "error: %s" % x[0])) as error_handler:
        result = pcoll | beam.Map(
            exception_throwing_map, limit=3).with_error_handler(
                error_handler, on_failure_callback=side_effect)
      error_pcoll = error_handler.output()

      assert_that(result, equal_to(['A', 'Bb']), label='CheckGood')
      assert_that(error_pcoll, equal_to(['error: cccc']), label='CheckBad')

    self.assertTrue(check_side_effect())

  def test_error_on_unclosed_error_handler(self):
    """Running a pipeline with an unclosed handler raises RuntimeError."""
    with self.assertRaisesRegex(RuntimeError, r'.*Unclosed error handler.*'):
      with beam.Pipeline() as p:
        pcoll = p | beam.Create(['a', 'bb', 'cccc'])
        # Use this outside of a context to allow it to remain unclosed.
        error_handler = error_handling.ErrorHandler(beam.Map(lambda x: x))
        _ = pcoll | PTransformWithErrors(3).with_error_handler(error_handler)

  def test_collecting_error_handler(self):
    """CollectingErrorHandler exposes the raw error records via output()."""
    with beam.Pipeline() as p:
      pcoll = p | beam.Create(['a', 'bb', 'cccc'])
      with error_handling.CollectingErrorHandler() as error_handler:
        result = pcoll | beam.Map(
            exception_throwing_map, limit=3).with_error_handler(error_handler)
      error_pcoll = error_handler.output() | beam.Map(lambda x: x[0])

      assert_that(result, equal_to(['A', 'Bb']), label='CheckGood')
      assert_that(error_pcoll, equal_to(['cccc']), label='CheckBad')

  def test_error_on_collecting_error_handler_without_output_retrieval(self):
    """A CollectingErrorHandler whose output is never read is rejected."""
    with self.assertRaisesRegex(
        RuntimeError,
        r'.*CollectingErrorHandler requires the output to be retrieved.*'):
      with beam.Pipeline() as p:
        pcoll = p | beam.Create(['a', 'bb', 'cccc'])
        with error_handling.CollectingErrorHandler() as error_handler:
          _ = pcoll | beam.Map(
              exception_throwing_map,
              limit=3).with_error_handler(error_handler)


if __name__ == '__main__':
  # Raise the root logger to INFO when the tests are run as a script.
  root_logger = logging.getLogger()
  root_logger.setLevel(logging.INFO)
  unittest.main()

0 comments on commit b3a874f

Please sign in to comment.