Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ErrorHandler DLQ API to Python #31856

Merged
merged 8 commits into from
Aug 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions sdks/python/apache_beam/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,7 @@ def __init__(
self.contains_external_transforms = False

self._display_data = display_data or {}
self._error_handlers = []

def display_data(self):
# type: () -> Dict[str, Any]
Expand All @@ -258,6 +259,9 @@ def allow_unsafe_triggers(self):
# type: () -> bool
return self._options.view_as(TypeOptions).allow_unsafe_triggers

def _register_error_handler(self, error_handler):
  """Tracks an error handler so it can be checked when the pipeline runs.

  `run()` calls `verify_closed()` on every handler recorded here before the
  pipeline is executed.

  Args:
    error_handler: the handler to track; it must provide `verify_closed()`.
  """
  # ErrorHandler.add_error_pcollection() registers its handler once per error
  # PCollection, so the same handler typically arrives here several times.
  # Compare by identity and keep a single entry so that each handler is
  # verified exactly once.
  already_tracked = any(
      known is error_handler for known in self._error_handlers)
  if not already_tracked:
    self._error_handlers.append(error_handler)

def _current_transform(self):
# type: () -> AppliedPTransform

Expand Down Expand Up @@ -531,6 +535,9 @@ def run(self, test_runner_api='AUTO'):

"""Runs the pipeline. Returns whatever our runner returns after running."""

for error_handler in self._error_handlers:
error_handler.verify_closed()

# Records whether this pipeline contains any cross-language transforms.
self.contains_external_transforms = (
ExternalTransformFinder.contains_external_transforms(self))
Expand Down
23 changes: 22 additions & 1 deletion sdks/python/apache_beam/transforms/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1573,6 +1573,7 @@ def with_exception_handling(
threshold=1,
threshold_windowing=None,
timeout=None,
error_handler=None,
on_failure_callback: typing.Optional[typing.Callable[
[Exception, typing.Any], None]] = None):
"""Automatically provides a dead letter output for skipping bad records.
Expand Down Expand Up @@ -1622,6 +1623,8 @@ def with_exception_handling(
defaults to the windowing of the input.
timeout: If the element has not finished processing in timeout seconds,
raise a TimeoutError. Defaults to None, meaning no time limit.
error_handler: An ErrorHandler that should be used to consume the bad
records, rather than returning the good and bad records as a tuple.
on_failure_callback: If an element fails or times out,
on_failure_callback will be invoked. It will receive the exception
and the element being processed in as args. In case of a timeout,
Expand All @@ -1642,8 +1645,20 @@ def with_exception_handling(
threshold,
threshold_windowing,
timeout,
error_handler,
on_failure_callback)

def with_error_handler(self, error_handler, **exception_handling_kwargs):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we also have a unit test that would show how the **exception_handling_kwargs can be used?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good call. Done.

"""An alias for `with_exception_handling(error_handler=error_handler, ...)`

This is provided to fit the general ErrorHandler conventions.
"""
if error_handler is None:
return self
else:
return self.with_exception_handling(
error_handler=error_handler, **exception_handling_kwargs)

def default_type_hints(self):
return self.fn.get_type_hints()

Expand Down Expand Up @@ -2242,6 +2257,7 @@ def __init__(
threshold,
threshold_windowing,
timeout,
error_handler,
on_failure_callback):
if partial and use_subprocess:
raise ValueError('partial and use_subprocess are mutually incompatible.')
Expand All @@ -2256,6 +2272,7 @@ def __init__(
self._threshold = threshold
self._threshold_windowing = threshold_windowing
self._timeout = timeout
self._error_handler = error_handler
self._on_failure_callback = on_failure_callback

def expand(self, pcoll):
Expand Down Expand Up @@ -2306,7 +2323,11 @@ def check_threshold(bad, total, threshold, window=DoFn.WindowParam):
_ = bad_count_pcoll | Map(
check_threshold, input_count_view, self._threshold)

return result
if self._error_handler:
self._error_handler.add_error_pcollection(result[self._dead_letter_tag])
return result[self._main_tag]
else:
return result


class _ExceptionHandlingWrapperDoFn(DoFn):
Expand Down
126 changes: 126 additions & 0 deletions sdks/python/apache_beam/transforms/error_handling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Utilities for gracefully handling errors and excluding bad elements."""

import traceback

from apache_beam import transforms


class ErrorHandler:
  """ErrorHandlers are used to skip and otherwise process bad records.

  Error handlers allow one to implement the "dead letter queue" pattern in
  a fluent manner, disaggregating the error processing specification from
  the main processing chain.

  This is typically used as follows::

    with error_handling.ErrorHandler(WriteToSomewhere(...)) as error_handler:
      result = pcoll | SomeTransform().with_error_handler(error_handler)

  in which case errors encountered by `SomeTransform()` in processing pcoll
  will be written by the PTransform `WriteToSomewhere(...)` and excluded from
  `result` rather than failing the pipeline.

  To implement `with_error_handler` on a PTransform, one caches the provided
  error handler for use in `expand`. During `expand()` one can invoke
  `error_handler.add_error_pcollection(...)` any number of times with
  PCollections containing error records to be processed by the given error
  handler, or (if applicable) simply invoke `with_error_handler(...)` on any
  subtransforms.

  The `with_error_handler` should accept `None` to indicate that error handling
  is not enabled (and make implementation-by-forwarding-error-handlers easier).
  In this case, any non-recoverable errors should fail the pipeline (e.g.
  propagate exceptions in `process` methods) rather than silently ignore errors.
  """
  def __init__(self, consumer):
    """Creates an open error handler.

    Args:
      consumer: the PTransform applied to the flattened error PCollections
        when the handler is closed.
    """
    self._consumer = consumer
    # format_stack() lists frames oldest-first, so [-1] is this __init__ and
    # [-2] is the code that constructed the handler. It is kept for the
    # "unclosed handler" error message.
    self._creation_traceback = traceback.format_stack()[-2]
    self._error_pcolls = []
    self._closed = False

  def __enter__(self):
    # Entering the context starts a fresh collection phase.
    self._error_pcolls = []
    self._closed = False
    return self

  def __exit__(self, *exec_info):
    # Only wire up the consumer on a clean exit; if the body raised, the
    # exception propagates and the handler is left unclosed.
    if exec_info[0] is None:
      self.close()

  def close(self):
    """Indicates all error-producing operations have reported any errors.

    Invokes the provided error consuming PTransform on any provided error
    PCollections.
    """
    self._output = (
        tuple(self._error_pcolls) | transforms.Flatten() | self._consumer)
    self._closed = True

  def output(self):
    """Returns result of applying the error consumer to the error pcollections.

    Raises:
      RuntimeError: if the handler has not been closed yet.
    """
    if not self._closed:
      raise RuntimeError(
          "Cannot access the output of an error handler "
          "until it has been closed.")
    return self._output

  def add_error_pcollection(self, pcoll):
    """Called by a class implementing error handling on the error records.

    Args:
      pcoll: a PCollection of error records; its pipeline is told about this
        handler so that `verify_closed()` is checked when the pipeline runs.

    Raises:
      RuntimeError: if the handler has already been closed. `close()` has
        already flattened the collected PCollections into the consumer, so a
        late addition would never be processed and its errors would be
        silently dropped.
    """
    if self._closed:
      raise RuntimeError(
          "Cannot add an error PCollection to an error handler "
          "that has already been closed.")
    pcoll.pipeline._register_error_handler(self)
    self._error_pcolls.append(pcoll)

  def verify_closed(self):
    """Called at end of pipeline construction to ensure errors are not ignored.

    Raises:
      RuntimeError: if the handler was never closed.
    """
    if not self._closed:
      raise RuntimeError(
          "Unclosed error handler initialized at %s" % self._creation_traceback)


class _IdentityPTransform(transforms.PTransform):
  """A pass-through PTransform whose expansion is its own input."""
  def expand(self, pcoll):
    # Nothing is applied: the input PCollection is handed back as-is.
    return pcoll


class CollectingErrorHandler(ErrorHandler):
  """An ErrorHandler that simply collects all errors for further processing.

  This ErrorHandler requires the set of errors be retrieved via `output()`
  and consumed (or explicitly discarded).
  """
  def __init__(self):
    super().__init__(_IdentityPTransform())
    # The base constructor recorded the frame that called it, which is this
    # __init__. Capture the caller again from here so that error messages
    # point at the code that created the handler.
    self._creation_traceback = traceback.format_stack()[-2]
    self._output_accessed = False

  def output(self):
    """Returns the collected errors and records that they were retrieved.

    Raises:
      RuntimeError: if the handler has not been closed yet.
    """
    # Fetch first: a failed access on a still-open handler must not count as
    # having retrieved the output.
    collected = super().output()
    self._output_accessed = True
    return collected

  def verify_closed(self):
    """Ensures the handler was closed and its output was retrieved.

    Raises:
      RuntimeError: if `output()` was never successfully called, or if the
        handler was never closed.
    """
    if not self._output_accessed:
      raise RuntimeError(
          "CollectingErrorHandler requires the output to be retrieved. "
          "Initialized at %s" % self._creation_traceback)
    return super().verify_closed()
148 changes: 148 additions & 0 deletions sdks/python/apache_beam/transforms/error_handling_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import logging
import unittest

import apache_beam as beam
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from apache_beam.transforms import error_handling


class PTransformWithErrors(beam.PTransform):
  """Test transform that title-cases short strings and rejects long ones.

  Elements whose length is below `limit` are title-cased and returned; the
  rest are routed to a 'bad' output. The bad output goes to the registered
  error handler if there is one; otherwise any bad element fails the pipeline.
  """
  def __init__(self, limit):
    # Let the PTransform base class initialise its own state first.
    super().__init__()
    self._limit = limit
    self._error_handler = None

  def with_error_handler(self, error_handler):
    """Registers `error_handler` (may be None) and returns this transform."""
    self._error_handler = error_handler
    return self

  def expand(self, pcoll):
    # Bind to a local so the nested function does not close over `self`.
    limit = self._limit

    def process(element):
      if len(element) < limit:
        return element.title()
      else:
        return beam.pvalue.TaggedOutput('bad', element)

    def raise_on_everything(element):
      raise ValueError(element)

    good, bad = pcoll | beam.Map(process).with_outputs('bad', main='good')
    if self._error_handler is not None:
      self._error_handler.add_error_pcollection(bad)
    else:
      # Will throw an exception if there are any bad elements.
      _ = bad | beam.Map(raise_on_everything)
    return good


def exception_throwing_map(x, limit):
  """Title-cases `x`, rejecting inputs longer than `limit`.

  Raises:
    ValueError: carrying `x`, when `len(x)` exceeds `limit`.
  """
  within_limit = len(x) <= limit
  if within_limit:
    return x.title()
  raise ValueError(x)


class ErrorHandlingTest(unittest.TestCase):
def test_error_handling(self):
  """Bad elements from a custom PTransform are routed to the error consumer."""
  with beam.Pipeline() as p:
    words = p | beam.Create(['a', 'bb', 'cccc'])
    consumer = beam.Map(lambda x: "error: %s" % x)
    with error_handling.ErrorHandler(consumer) as error_handler:
      transform = PTransformWithErrors(3).with_error_handler(error_handler)
      result = words | transform
    # The output only becomes available once the handler has been closed.
    error_pcoll = error_handler.output()

    assert_that(result, equal_to(['A', 'Bb']), label='CheckGood')
    assert_that(error_pcoll, equal_to(['error: cccc']), label='CheckBad')

def test_error_handling_pardo(self):
with beam.Pipeline() as p:
pcoll = p | beam.Create(['a', 'bb', 'cccc'])
with error_handling.ErrorHandler(
beam.Map(lambda x: "error: %s" % x[0])) as error_handler:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we access the first item of the record here (Map(lambda x: "error: %s" % x[0])), whereas the test above uses the entire element (Map(lambda x: "error: %s" % x))?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PTransformWithErrors.with_error_handler() emits the bad elements themselves as errors, whereas the standard ParDo(...).with_error_handler() emits the Python equivalent of bad records, which attach the bad element to the exception thrown.

result = pcoll | beam.Map(
exception_throwing_map, limit=3).with_error_handler(error_handler)
error_pcoll = error_handler.output()

assert_that(result, equal_to(['A', 'Bb']), label='CheckGood')
assert_that(error_pcoll, equal_to(['error: cccc']), label='CheckBad')

def test_error_handling_pardo_with_exception_handling_kwargs(self):
  """Extra kwargs of `with_error_handler` reach `with_exception_handling`.

  `on_failure_callback` is used as the probe: the callback sets a marker
  attribute on the `beam` module, which is checked after the pipeline runs.
  """
  def side_effect(*args):
    beam._test_error_handling_pardo_with_exception_handling_kwargs_val = True

  def check_side_effect():
    return getattr(
        beam,
        '_test_error_handling_pardo_with_exception_handling_kwargs_val',
        False)

  # The marker lives on the shared `beam` module. Remove it when the test
  # finishes so that a later run in the same process starts clean.
  self.addCleanup(
      vars(beam).pop,
      '_test_error_handling_pardo_with_exception_handling_kwargs_val',
      None)

  self.assertFalse(check_side_effect())

  with beam.Pipeline() as p:
    pcoll = p | beam.Create(['a', 'bb', 'cccc'])
    with error_handling.ErrorHandler(
        beam.Map(lambda x: "error: %s" % x[0])) as error_handler:
      result = pcoll | beam.Map(
          exception_throwing_map, limit=3).with_error_handler(
              error_handler, on_failure_callback=side_effect)
    error_pcoll = error_handler.output()

    assert_that(result, equal_to(['A', 'Bb']), label='CheckGood')
    assert_that(error_pcoll, equal_to(['error: cccc']), label='CheckBad')

  self.assertTrue(check_side_effect())

def test_error_on_unclosed_error_handler(self):
  """Running a pipeline with a never-closed handler raises RuntimeError."""
  expected_message = r'.*Unclosed error handler.*'
  with self.assertRaisesRegex(RuntimeError, expected_message):
    with beam.Pipeline() as p:
      words = p | beam.Create(['a', 'bb', 'cccc'])
      # Deliberately not used as a context manager, so nothing ever closes it.
      unclosed_handler = error_handling.ErrorHandler(beam.Map(lambda x: x))
      transform = PTransformWithErrors(3).with_error_handler(unclosed_handler)
      _ = words | transform

def test_collecting_error_handler(self):
  """CollectingErrorHandler exposes the raw bad records via `output()`."""
  with beam.Pipeline() as p:
    words = p | beam.Create(['a', 'bb', 'cccc'])
    with error_handling.CollectingErrorHandler() as error_handler:
      mapper = beam.Map(exception_throwing_map, limit=3)
      result = words | mapper.with_error_handler(error_handler)
    bad_records = error_handler.output()
    # Keep only the first item of each bad record for comparison.
    error_pcoll = bad_records | beam.Map(lambda x: x[0])

    assert_that(result, equal_to(['A', 'Bb']), label='CheckGood')
    assert_that(error_pcoll, equal_to(['cccc']), label='CheckBad')

def test_error_on_collecting_error_handler_without_output_retrieval(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shall we also add a unit test with the CollectingErrorHandler that is closed but not consumed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is closed (due to the context) but not consumed.

with self.assertRaisesRegex(
RuntimeError,
r'.*CollectingErrorHandler requires the output to be retrieved.*'):
with beam.Pipeline() as p:
pcoll = p | beam.Create(['a', 'bb', 'cccc'])
with error_handling.CollectingErrorHandler() as error_handler:
_ = pcoll | beam.Map(
exception_throwing_map,
limit=3).with_error_handler(error_handler)


if __name__ == '__main__':
  # Show INFO-level log output when the tests are run directly as a script.
  root_logger = logging.getLogger()
  root_logger.setLevel(logging.INFO)
  unittest.main()
Loading