Add docs on pre/post processing operations and dlq support #26772

Merged · 4 commits · May 31, 2023
@@ -192,7 +192,12 @@ with pipeline as p:

For more information on resource hints, see [Resource hints](/documentation/runtime/resource-hints/).

## Use a keyed ModelHandler
## RunInference Patterns
Contributor:

We should include text after this header describing the section. Something along the lines of:

This section provides examples of RunInference patterns.

(But hopefully more informative than that.)

Contributor Author:

Done - let me know if you think it sounds ok: "This section suggests patterns and best practices that can be used to make your inference pipelines simpler, more robust, and more efficient."


This section suggests patterns and best practices that you can use to make your inference pipelines simpler,
more robust, and more efficient.

### Use a keyed ModelHandler

If a key is attached to the examples, wrap the `KeyedModelHandler` around the `ModelHandler` object:
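
For example, a minimal sketch, assuming `keyed_examples` is a `PCollection` of `(key, example)` tuples and `model_handler` is an unkeyed handler:

```
from apache_beam.ml.inference.base import KeyedModelHandler

keyed_model_handler = KeyedModelHandler(model_handler)
inference = keyed_examples | RunInference(keyed_model_handler)  # emits (key, PredictionResult) pairs
```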

@@ -212,7 +217,7 @@ If you are unsure if your data is keyed, you can also use `MaybeKeyedModelHandler`

For more information, see [`KeyedModelHandler`](https://beam.apache.org/releases/pydoc/current/apache_beam.ml.inference.base.html#apache_beam.ml.inference.base.KeyedModelHandler).

## Use the `PredictionResult` object
### Use the `PredictionResult` object

When doing a prediction in Apache Beam, the output `PCollection` includes both the keys of the input examples and the inferences. Including both these items in the output allows you to find the input that determined the predictions.
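
For example, a minimal sketch of a `DoFn` that unpacks both fields from each `PredictionResult` (the `ExtractResults` name is illustrative):

```
import apache_beam as beam
from apache_beam.ml.inference.base import PredictionResult

class ExtractResults(beam.DoFn):
    def process(self, element: PredictionResult):
        # Keep the original input alongside its inference so they can be joined downstream.
        yield element.example, element.inference
```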

@@ -245,12 +250,7 @@ from apache_beam.ml.inference.base import PredictionResult

For more information, see the [`PredictionResult` documentation](https://github.com/apache/beam/blob/master/sdks/python/apache_beam/ml/inference/base.py#L65).

## Run a machine learning pipeline

For detailed instructions explaining how to build and run a Python pipeline that uses ML models, see the
[Example RunInference API pipelines](https://github.com/apache/beam/tree/master/sdks/python/apache_beam/examples/inference) on GitHub.

## Automatic model refresh
### Automatic model refresh
To automatically update the models used with the RunInference `PTransform` without stopping the Beam pipeline, pass a [`ModelMetadata`](https://beam.apache.org/releases/pydoc/current/apache_beam.ml.inference.base.html#apache_beam.ml.inference.base.ModelMetadata) side input `PCollection` to the RunInference input parameter `model_metadata_pcoll`.

`ModelMetadata` is a `NamedTuple` containing:
@@ -267,7 +267,65 @@ The side input `PCollection` must follow the [`AsSingleton`](https://beam.apache
**Note**: If the main `PCollection` emits inputs but the side input has yet to receive inputs, the main `PCollection` is buffered until
the side input is updated. This can happen with globally windowed side inputs that use data-driven triggers, such as `AfterCount` and `AfterProcessingTime`. Until the side input is updated, emit the default or initial model ID as the side input value so that the respective `ModelHandler` can be used for inference.
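
As a minimal sketch, one way to produce this side input is with `PeriodicImpulse`. The refresh interval and the `gs://my-bucket/latest_model` path below are hypothetical, and `pipeline`, `pcoll`, and `model_handler` are assumed to be defined elsewhere:

```
import apache_beam as beam
from apache_beam.ml.inference.base import ModelMetadata, RunInference
from apache_beam.transforms import trigger, window
from apache_beam.transforms.periodicsequence import PeriodicImpulse

# Periodically emit a ModelMetadata update that points at the newest model artifacts.
model_updates = (
    pipeline
    | PeriodicImpulse(fire_interval=3600, apply_windowing=False)
    | beam.Map(lambda _: ModelMetadata(
        model_id='gs://my-bucket/latest_model',   # hypothetical path to the refreshed model
        model_name='latest_model'))
    # Globally windowed side input with a processing-time trigger, as required above.
    | beam.WindowInto(
        window.GlobalWindows(),
        trigger=trigger.Repeatedly(trigger.AfterProcessingTime(3600)),
        accumulation_mode=trigger.AccumulationMode.DISCARDING))

inference = pcoll | RunInference(model_handler, model_metadata_pcoll=model_updates)
```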

## Beam Java SDK support
### Preprocess and postprocess your records

With RunInference, you can add preprocessing and postprocessing operations to your transform.
To apply preprocessing operations, use `with_preprocess_fn` on your model handler:

```
inference = pcoll | RunInference(model_handler.with_preprocess_fn(lambda x: do_something(x)))
```

To apply postprocessing operations, use `with_postprocess_fn` on your model handler:

```
inference = pcoll | RunInference(model_handler.with_postprocess_fn(lambda x: do_something_to_result(x)))
```

You can also chain multiple pre- and postprocessing operations:

```
inference = pcoll | RunInference(
    model_handler.with_preprocess_fn(
        lambda x: do_something(x)
    ).with_preprocess_fn(
        lambda x: do_something_else(x)
    ).with_postprocess_fn(
        lambda x: do_something_after_inference(x)
    ).with_postprocess_fn(
        lambda x: do_something_else_after_inference(x)
    ))
```

The preprocessing function is run before batching and inference. This function maps your input `PCollection`
to the base input type of the model handler. If you apply multiple preprocessing functions, they run on your original
`PCollection` in the order of last applied to first applied.

The postprocessing function runs after inference. This function maps the output type of the base model handler
to your desired output type. If you apply multiple postprocessing functions, they run on your original
inference result in the order of first applied to last applied.
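
For example, in the following sketch (the helper functions `parse_row`, `to_tensor`, `extract_label`, and `format_output` are hypothetical), the comments spell out the resulting per-element order:

```
inference = pcoll | RunInference(
    model_handler
    .with_preprocess_fn(to_tensor)         # applied first, runs second on each element
    .with_preprocess_fn(parse_row)         # applied last, runs first on each raw element
    .with_postprocess_fn(extract_label)    # applied first, runs first on each inference result
    .with_postprocess_fn(format_output))   # applied last, runs last

# Per-element flow:
#   raw element -> parse_row -> to_tensor -> batching + inference -> extract_label -> format_output
```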

### Handle errors while using RunInference

To handle errors robustly while using RunInference, you can use a _dead-letter queue_. The dead-letter queue outputs failed records into a separate `PCollection` for further processing.
This `PCollection` can then be analyzed and sent to a storage system, where it can be reviewed and resubmitted to the pipeline, or discarded.
RunInference has built-in support for dead-letter queues. You can use a dead-letter queue by applying `with_exception_handling` to your RunInference transform:

```
main, other = pcoll | RunInference(model_handler).with_exception_handling()
other.failed_inferences | beam.Map(print) # insert logic to handle failed records here
```
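
Continuing the example above, a minimal sketch of sending the failed records to a storage system for later review (the output path is hypothetical):

```
# Each failed record includes the failed input and the raised error; repr() gives a reviewable string.
other.failed_inferences | beam.Map(repr) | beam.io.WriteToText(
    'gs://my-bucket/failed-inferences')  # hypothetical output location
```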

You can also apply this pattern to RunInference transforms with associated pre- and postprocessing operations:

```
main, other = pcoll | RunInference(model_handler.with_preprocess_fn(f1).with_postprocess_fn(f2)).with_exception_handling()
other.failed_preprocessing[0] | beam.Map(print) # handles failed preprocess operations, indexed in the order in which they were applied
other.failed_inferences | beam.Map(print) # handles failed inferences
other.failed_postprocessing[0] | beam.Map(print) # handles failed postprocess operations, indexed in the order in which they were applied
```

### Run inference from a Java pipeline

The RunInference API is available with the Beam Java SDK versions 2.41.0 and later through Apache Beam's [Multi-language Pipelines framework](/documentation/programming-guide/#multi-language-pipelines). For information about the Java wrapper transform, see [RunInference.java](https://github.com/apache/beam/blob/master/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/transforms/RunInference.java). To try it out, see the [Java Sklearn Mnist Classification example](https://github.com/apache/beam/tree/master/examples/multi-language). Additionally, see [Using RunInference from Java SDK](https://beam.apache.org/documentation/ml/multi-language-inference/) for an example of a composite Python transform that uses the RunInference API along with preprocessing and postprocessing from a Beam Java SDK pipeline.
