From be052ec45feb4f37f96a08d4623f2275c58d6937 Mon Sep 17 00:00:00 2001 From: Danny McCormick Date: Wed, 31 May 2023 15:53:37 -0400 Subject: [PATCH] Add docs on pre/post processing operations and dlq support (#26772) * Add docs on pre/post processing operations and dlq support * Apply suggestions from code review Co-authored-by: Rebecca Szper <98840847+rszper@users.noreply.github.com> * Add description to RunInference Patterns * Wording Co-authored-by: Rebecca Szper <98840847+rszper@users.noreply.github.com> --------- Co-authored-by: Rebecca Szper <98840847+rszper@users.noreply.github.com> --- .../sdks/python-machine-learning.md | 76 ++++++++++++++++--- 1 file changed, 67 insertions(+), 9 deletions(-) diff --git a/website/www/site/content/en/documentation/sdks/python-machine-learning.md b/website/www/site/content/en/documentation/sdks/python-machine-learning.md index ccd64c3412f6f..5e0cf483ff3ea 100644 --- a/website/www/site/content/en/documentation/sdks/python-machine-learning.md +++ b/website/www/site/content/en/documentation/sdks/python-machine-learning.md @@ -192,7 +192,12 @@ with pipeline as p: For more information on resource hints, see [Resource hints](/documentation/runtime/resource-hints/). -## Use a keyed ModelHandler +## RunInference Patterns + +This section suggests patterns and best practices that you can use to make your inference pipelines simpler, +more robust, and more efficient. + +### Use a keyed ModelHandler If a key is attached to the examples, wrap the `KeyedModelHandler` around the `ModelHandler` object: @@ -212,7 +217,7 @@ If you are unsure if your data is keyed, you can also use `MaybeKeyedModelHandle For more information, see [`KeyedModelHander`](https://beam.apache.org/releases/pydoc/current/apache_beam.ml.inference.base.html#apache_beam.ml.inference.base.KeyedModelHandler). -## Use the `PredictionResult` object +### Use the `PredictionResult` object When doing a prediction in Apache Beam, the output `PCollection` includes both the keys of the input examples and the inferences. Including both these items in the output allows you to find the input that determined the predictions. @@ -245,12 +250,7 @@ from apache_beam.ml.inference.base import PredictionResult For more information, see the [`PredictionResult` documentation](https://github.com/apache/beam/blob/master/sdks/python/apache_beam/ml/inference/base.py#L65). -## Run a machine learning pipeline - -For detailed instructions explaining how to build and run a Python pipeline that uses ML models, see the -[Example RunInference API pipelines](https://github.com/apache/beam/tree/master/sdks/python/apache_beam/examples/inference) on GitHub. - -## Automatic model refresh +### Automatic model refresh To automatically update the models used with the RunInference `PTransform` without stopping the Beam pipeline, pass a [`ModelMetadata`](https://beam.apache.org/releases/pydoc/current/apache_beam.ml.inference.base.html#apache_beam.ml.inference.base.ModelMetadata) side input `PCollection` to the RunInference input parameter `model_metadata_pcoll`. `ModelMetdata` is a `NamedTuple` containing: @@ -267,7 +267,65 @@ The side input `PCollection` must follow the [`AsSingleton`](https://beam.apache **Note**: If the main `PCollection` emits inputs and a side input has yet to receive inputs, the main `PCollection` is buffered until there is an update to the side input. This could happen with global windowed side inputs with data driven triggers, such as `AfterCount` and `AfterProcessingTime`. Until the side input is updated, emit the default or initial model ID that is used to pass the respective `ModelHandler` as a side input. -## Beam Java SDK support +### Preprocess and postprocess your records + +With RunInference, you can add preprocessing and postprocessing operations to your transform. +To apply preprocessing operations, use `with_preprocess_fn` on your model handler: + +``` +inference = pcoll | RunInference(model_handler.with_preprocess_fn(lambda x : do_something(x))) +``` + +To apply postprocessing operations, use `with_postprocess_fn` on your model handler: + +``` +inference = pcoll | RunInference(model_handler.with_postprocess_fn(lambda x : do_something_to_result(x))) +``` + +You can also chain multiple pre- and postprocessing operations: + +``` +inference = pcoll | RunInference( + model_handler.with_preprocess_fn( + lambda x : do_something(x) + ).with_preprocess_fn( + lambda x : do_something_else(x) + ).with_postprocess_fn( + lambda x : do_something_after_inference(x) + ).with_postprocess_fn( + lambda x : do_something_else_after_inference(x) + )) +``` + +The preprocessing function is run before batching and inference. This function maps your input `PCollection` +to the base input type of the model handler. If you apply multiple preprocessing functions, they run on your original +`PCollection` in the order of last applied to first applied. + +The postprocessing function runs after inference. This function maps the output type of the base model handler +to your desired output type. If you apply multiple postprocessing functions, they run on your original +inference result in the order of first applied to last applied. + +### Handle errors while using RunInference + +To handle errors robustly while using RunInference, you can use a _dead-letter queue_. The dead-letter queue outputs failed records into a separate `PCollection` for further processing. +This `PCollection` can then be analyzed and sent to a storage system, where it can be reviewed and resubmitted to the pipeline, or discarded. +RunInference has built-in support for dead-letter queues. You can use a dead-letter queue by applying `with_exception_handling` to your RunInference transform: + +``` +main, other = pcoll | RunInference(model_handler).with_exception_handling() +other.failed_inferences | beam.Map(print) # insert logic to handle failed records here +``` + +You can also apply this pattern to RunInference transforms with associated pre- and postprocessing operations: + +``` +main, other = pcoll | RunInference(model_handler.with_preprocess_fn(f1).with_postprocess_fn(f2)).with_exception_handling() +other.failed_preprocessing[0] | beam.Map(print) # handles failed preprocess operations, indexed in the order in which they were applied +other.failed_inferences | beam.Map(print) # handles failed inferences +other.failed_postprocessing[0] | beam.Map(print) # handles failed postprocess operations, indexed in the order in which they were applied +``` + +### Run inference from a Java pipeline The RunInference API is available with the Beam Java SDK versions 2.41.0 and later through Apache Beam's [Multi-language Pipelines framework](/documentation/programming-guide/#multi-language-pipelines). For information about the Java wrapper transform, see [RunInference.java](https://github.com/apache/beam/blob/master/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/transforms/RunInference.java). To try it out, see the [Java Sklearn Mnist Classification example](https://github.com/apache/beam/tree/master/examples/multi-language). Additionally, see [Using RunInference from Java SDK](https://beam.apache.org/documentation/ml/multi-language-inference/) for an example of a composite Python transform that uses the RunInference API along with preprocessing and postprocessing from a Beam Java SDK pipeline.