diff --git a/CHANGES.md b/CHANGES.md index 3113c87500546..7be9752517d58 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -64,7 +64,7 @@ ## New Features / Improvements -* X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)). +* Dead letter queue support added to RunInference in Python ([#24209](https://github.com/apache/beam/issues/24209)). ## Breaking Changes diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py index 50056107702e6..5bc3bb3155784 100644 --- a/sdks/python/apache_beam/ml/inference/base.py +++ b/sdks/python/apache_beam/ml/inference/base.py @@ -339,6 +339,7 @@ def __init__( self._metrics_namespace = metrics_namespace self._model_metadata_pcoll = model_metadata_pcoll self._enable_side_input_loading = self._model_metadata_pcoll is not None + self._with_exception_handling = False # TODO(BEAM-14046): Add and link to help documentation. @classmethod @@ -368,20 +369,71 @@ def expand( # batching DoFn APIs. | beam.BatchElements(**self._model_handler.batch_elements_kwargs())) + run_inference_pardo = beam.ParDo( + _RunInferenceDoFn( + self._model_handler, + self._clock, + self._metrics_namespace, + self._enable_side_input_loading), + self._inference_args, + beam.pvalue.AsSingleton( + self._model_metadata_pcoll, + ) if self._enable_side_input_loading else None).with_resource_hints( + **resource_hints) + + if self._with_exception_handling: + run_inference_pardo = run_inference_pardo.with_exception_handling( + exc_class=self._exc_class, + use_subprocess=self._use_subprocess, + threshold=self._threshold) + return ( batched_elements_pcoll - | 'BeamML_RunInference' >> ( - beam.ParDo( - _RunInferenceDoFn( - self._model_handler, - self._clock, - self._metrics_namespace, - self._enable_side_input_loading), - self._inference_args, - beam.pvalue.AsSingleton( - self._model_metadata_pcoll, - ) if self._enable_side_input_loading else - None).with_resource_hints(**resource_hints))) + | 'BeamML_RunInference' >> run_inference_pardo) + + def with_exception_handling( + self, *, exc_class=Exception, use_subprocess=False, threshold=1): + """Automatically provides a dead letter output for skipping bad records. + This can allow a pipeline to continue successfully rather than fail or + continuously throw errors on retry when bad elements are encountered. + + This returns a tagged output with two PCollections, the first being the + results of successfully processing the input PCollection, and the second + being the set of bad batches of records (those which threw exceptions + during processing) along with information about the errors raised. + + For example, one would write:: + + good, bad = RunInference( + maybe_error_raising_model_handler + ).with_exception_handling() + + and `good` will be a PCollection of PredictionResults and `bad` will + contain a tuple of all batches that raised exceptions, along with their + corresponding exception. + + + Args: + exc_class: An exception class, or tuple of exception classes, to catch. + Optional, defaults to 'Exception'. + use_subprocess: Whether to execute the DoFn logic in a subprocess. This + allows one to recover from errors that can crash the calling process + (e.g. from an underlying library causing a segfault), but is + slower as elements and results must cross a process boundary. Note + that this starts up a long-running process that is used to handle + all the elements (until hard failure, which should be rare) rather + than a new process per element, so the overhead should be minimal + (and can be amortized if there's any per-process or per-bundle + initialization that needs to be done). Optional, defaults to False. + threshold: An upper bound on the ratio of records that can be bad before + aborting the entire pipeline. Optional, defaults to 1.0 (meaning + up to 100% of records can be bad and the pipeline will still succeed). + """ + self._with_exception_handling = True + self._exc_class = exc_class + self._use_subprocess = use_subprocess + self._threshold = threshold + return self class _MetricsCollector: diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py index dad18c7b9e186..da82095a4e9d6 100644 --- a/sdks/python/apache_beam/ml/inference/base_test.py +++ b/sdks/python/apache_beam/ml/inference/base_test.py @@ -48,8 +48,10 @@ def predict(self, example: int) -> int: class FakeModelHandler(base.ModelHandler[int, int, FakeModel]): - def __init__(self, clock=None): + def __init__(self, clock=None, min_batch_size=1, max_batch_size=9999): self._fake_clock = clock + self._min_batch_size = min_batch_size + self._max_batch_size = max_batch_size def load_model(self): if self._fake_clock: @@ -69,6 +71,12 @@ def run_inference( def update_model_path(self, model_path: Optional[str] = None): pass + def batch_elements_kwargs(self): + return { + 'min_batch_size': self._min_batch_size, + 'max_batch_size': self._max_batch_size + } + class FakeModelHandlerReturnsPredictionResult( base.ModelHandler[int, base.PredictionResult, FakeModel]): @@ -171,6 +179,24 @@ def test_run_inference_impl_with_maybe_keyed_examples(self): model_handler) assert_that(keyed_actual, equal_to(keyed_expected), label='CheckKeyed') + def test_run_inference_impl_dlq(self): + with TestPipeline() as pipeline: + examples = [1, 'TEST', 3, 10, 'TEST2'] + expected_good = [2, 4, 11] + expected_bad = ['TEST', 'TEST2'] + pcoll = pipeline | 'start' >> beam.Create(examples) + good, bad = pcoll | base.RunInference( + FakeModelHandler( + min_batch_size=1, + max_batch_size=1 + )).with_exception_handling() + assert_that(good, equal_to(expected_good), label='assert:inferences') + + # bad will be in form [batch[elements], error]. Just pull out bad element. + bad_without_error = bad | beam.Map(lambda x: x[0][0]) + assert_that( + bad_without_error, equal_to(expected_bad), label='assert:failures') + def test_run_inference_impl_inference_args(self): with TestPipeline() as pipeline: examples = [1, 5, 3, 10]