[jax2tf] Added instructions for using OSS TensorFlow model server.
Showing 4 changed files with 259 additions and 6 deletions.
Using jax2tf with TensorFlow serving
====================================

This is a supplement to the
[examples/README.md](https://g3doc.corp.google.com/third_party/py/jax/experimental/jax2tf/examples/README.md)
with example code and
instructions for using `jax2tf` with the OSS TensorFlow model server.
For Google-internal versions of the model server, see the `internal` subdirectory.

The goal of `jax2tf` is to convert JAX functions
into Python functions that behave as if they had been written with TensorFlow.
These functions can be traced and saved in a SavedModel using **standard TensorFlow
code, so the user has full control over what metadata is saved in the
SavedModel**.

The only difference in a SavedModel produced with jax2tf is that the
function graphs may contain
[XLA TF ops](http://g3doc/third_party/py/jax/experimental/jax2tf/README.md#caveats)
that require enabling XLA for execution in the model server. This
is achieved using a command-line flag. There are no other differences compared
to using a SavedModel produced by TensorFlow.
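As a minimal sketch of that workflow (the example scripts below do this for you; the function and input shape here are illustrative assumptions), a JAX function can be converted with `jax2tf.convert` and saved with standard TensorFlow APIs:

```python
# Minimal sketch: convert a JAX function with jax2tf and save it as a
# SavedModel using standard TensorFlow APIs. The function and the input
# shape are illustrative; the example's real flow is in saved_model_main.py.
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

def jax_fn(x):
  return jnp.sin(jnp.cos(x))

module = tf.Module()
module.f = tf.function(jax2tf.convert(jax_fn), autograph=False)
tf.saved_model.save(
    module, "/tmp/jax2tf/sketch_model",
    signatures=module.f.get_concrete_function(
        tf.TensorSpec([1, 28, 28, 1], tf.float32)))
```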

This serving example uses
[saved_model_main.py](http://google3/third_party/py/jax/experimental/jax2tf/examples/saved_model_main.py)
for saving the SavedModel, and adds code specific to interacting with the
model server:
[model_server_request.py](http://google3/third_party/py/jax/experimental/jax2tf/examples/serving/model_server_request.py).

0. *Set up JAX and TensorFlow serving*.

If you have already installed JAX and TensorFlow serving, you can skip most of these steps, but do set the
environment variables `JAX2TF_EXAMPLES` and `DOCKER_IMAGE`.

The following will clone a copy of the JAX sources locally and install the `jax`, `jaxlib`, and `flax` packages.
We also need to install TensorFlow for the `jax2tf` feature and for the rest of this example.
We use the `tf_nightly` package to get an up-to-date version.

```shell
git clone https://github.com/google/jax
JAX2TF_EXAMPLES=$(pwd)/jax/jax/experimental/jax2tf/examples
pip install -e jax
pip install flax jaxlib tensorflow_datasets tensorflow_serving_api tf_nightly
```
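As an optional sanity check (not part of the example scripts), you can confirm that the installed packages import and that `jax2tf` is available:

```python
# Optional sanity check: confirm that JAX, TensorFlow, and jax2tf import.
import jax
import tensorflow as tf
from jax.experimental import jax2tf  # noqa: F401

print("jax", jax.__version__)
print("tensorflow", tf.__version__)
```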

We then install [TensorFlow serving in a Docker image](https://www.tensorflow.org/tfx/serving/docker),
again using a recent "nightly" version. Install Docker and then run:

```shell
DOCKER_IMAGE=tensorflow/serving:nightly
docker pull ${DOCKER_IMAGE}
```

1. *Set some variables*.

```shell
# Shortcuts
# Where to save SavedModels
MODEL_PATH=/tmp/jax2tf/saved_models
# The example model. The options are "mnist_flax" and "mnist_pure_jax".
MODEL=mnist_flax
# The batch size for the SavedModel.
SERVING_BATCH_SIZE=1
# Increment this when you make changes to the model parameters after the
# initial model generation (Step 2 below).
MODEL_VERSION=$(( 1 + ${MODEL_VERSION:-0} ))
```

2. *Train and export a model.* You will use `saved_model_main.py` to train
a model and export it as a SavedModel (see more details in README.md). Execute:

```shell
python ${JAX2TF_EXAMPLES}/saved_model_main.py --model=${MODEL} \
    --model_path=${MODEL_PATH} --model_version=${MODEL_VERSION} \
    --serving_batch_size=${SERVING_BATCH_SIZE} \
    --compile_model \
    --noshow_model
```

The SavedModel will be in `${MODEL_PATH}/${MODEL}/${MODEL_VERSION}`.
You can inspect it with `saved_model_cli`:

```shell
saved_model_cli show --all --dir ${MODEL_PATH}/${MODEL}/${MODEL_VERSION}
```
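Optionally (this is not part of the example scripts), you can also load the SavedModel back into Python for a quick local check before serving; the hard-coded path below just spells out `${MODEL_PATH}/${MODEL}/${MODEL_VERSION}` for the values set above:

```python
# Optional local check, assuming MODEL_PATH=/tmp/jax2tf/saved_models,
# MODEL=mnist_flax and MODEL_VERSION=1 as set above.
import tensorflow as tf

loaded = tf.saved_model.load("/tmp/jax2tf/saved_models/mnist_flax/1")
serving_fn = loaded.signatures["serving_default"]
print(serving_fn.structured_input_signature)  # expected input shape, e.g. [1, 28, 28, 1]
print(serving_fn.structured_outputs)          # expected output shape, e.g. [1, 10]

# Run the serving signature on a zero-filled batch of the serving batch size.
print(serving_fn(tf.zeros((1, 28, 28, 1), dtype=tf.float32)))
```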

3. *Start a local model server* with XLA compilation enabled. Execute:

```shell
docker run -p 8500:8500 -p 8501:8501 \
    --mount type=bind,source=${MODEL_PATH}/${MODEL}/,target=/models/${MODEL} \
    -e MODEL_NAME=${MODEL} -t --rm --name=serving ${DOCKER_IMAGE} \
    --xla_cpu_compilation_enabled=true &
```

Note that we are forwarding port 8500 (for the gRPC server) and
port 8501 (for the HTTP REST server).

You do not need to redo this step if you change the model parameters and
regenerate the SavedModel, as long as you bump `${MODEL_VERSION}`.
The running model server will automatically load newer model versions.
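To confirm that the server has loaded the model (optional; not part of the example scripts), you can query TensorFlow Serving's model status REST endpoint; the hard-coded model name below stands in for `${MODEL}`:

```python
# Optional: list the model versions the server has loaded and their state.
import requests

model_name = "mnist_flax"  # stands in for ${MODEL}
resp = requests.get(f"http://localhost:8501/v1/models/{model_name}")
resp.raise_for_status()
print(resp.json())  # each loaded version should report state "AVAILABLE"
```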

4. *Send RPC requests.* Execute:

```shell
python ${JAX2TF_EXAMPLES}/serving/model_server_request.py --model_spec_name=${MODEL} \
    --use_grpc --prediction_service_addr=localhost:8500 \
    --serving_batch_size=${SERVING_BATCH_SIZE} \
    --count_images=128
```

If you see an error `Input to reshape is a tensor with 12544 values, but the requested shape has 784`,
then your serving batch size is 16 (= 12544 / 784) while the model loaded
in the model server has batch size 1 (= 784 / 784). You should check that
you are using the same `--serving_batch_size` for the model generation and
for sending the requests.
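To see the batch size the server actually loaded (optional; not part of the example scripts), you can read the model's signature from the REST metadata endpoint; the first dimension of the `inputs` tensor shape is the serving batch size:

```python
# Optional: dump the loaded model's signature; the input tensor shape's
# first dimension is the serving batch size of the loaded model.
import json
import requests

model_name = "mnist_flax"  # stands in for ${MODEL}
resp = requests.get(f"http://localhost:8501/v1/models/{model_name}/metadata")
resp.raise_for_status()
print(json.dumps(resp.json(), indent=2))
```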

5. *Experiment with different models and batch sizes*.

- You can set `MODEL=mnist_pure_jax` to use a simpler model written in
pure JAX, then restart from Step 1.

- You can change the batch size at which the model is converted from JAX
and saved. Set `SERVING_BATCH_SIZE=16` and restart from Step 2.
In Step 4, you should pass a `--count_images`
parameter that is a multiple of the serving batch size you choose.
jax/experimental/jax2tf/examples/serving/model_server_request.py (127 additions, 0 deletions)
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Demonstrates using jax2tf with the TensorFlow model server.

See README.md for instructions.
"""
import grpc
import json
import logging
import requests

from absl import app
from absl import flags

from jax.experimental.jax2tf.examples import mnist_lib

import numpy as np
import tensorflow as tf  # type: ignore
import tensorflow_datasets as tfds  # type: ignore
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2_grpc


FLAGS = flags.FLAGS

flags.DEFINE_boolean(
    "use_grpc", True,
    "Use the gRPC API (default), or the HTTP REST API.")

flags.DEFINE_string(
    "model_spec_name", "",
    "The name you used to export your model to the model server (e.g., mnist_flax).")

flags.DEFINE_string(
    "prediction_service_addr",
    "localhost:8500",
    "Endpoint for the prediction service. If you serve your model "
    "locally using TensorFlow model server, then you can use \"localhost:8500\" "
    "for the gRPC server and \"localhost:8501\" for the HTTP REST server.")

flags.DEFINE_integer("serving_batch_size", 1,
                     "Batch size for the serving request. Must match the "
                     "batch size at which the model was saved. Must divide "
                     "--count_images.",
                     lower_bound=1)
flags.DEFINE_integer("count_images", 16,
                     "How many images to test.",
                     lower_bound=1)
def serving_call_mnist(images):
  """Send an RPC or REST request to the model server.

  Args:
    images: A numpy.ndarray of shape [B, 28, 28, 1] with the batch of images to
      perform inference on.

  Returns:
    A numpy.ndarray of shape [B, 10] with the one-hot inference response.
  """
  if FLAGS.use_grpc:
    channel = grpc.insecure_channel(FLAGS.prediction_service_addr)
    stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)

    request = predict_pb2.PredictRequest()
    request.model_spec.name = FLAGS.model_spec_name
    request.model_spec.signature_name = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
    # You can see the name of the input ("inputs") in the SavedModel dump.
    request.inputs["inputs"].CopyFrom(
        tf.make_tensor_proto(images, dtype=images.dtype, shape=images.shape))
    response = stub.Predict(request)
    # We could also use response.outputs["output_0"], where "output_0" is the
    # name of the output (which you can see in the SavedModel dump).
    # Here, we simply take the first output.
    outputs, = response.outputs.values()
    return tf.make_ndarray(outputs)
  else:
    # Use the HTTP REST API.
    images_json = json.dumps(images.tolist())
    # You can see the name of the input ("inputs") in the SavedModel dump.
    data = f'{{"inputs": {images_json}}}'
    predict_url = f"http://{FLAGS.prediction_service_addr}/v1/models/{FLAGS.model_spec_name}:predict"
    response = requests.post(predict_url, data=data)
    if response.status_code != 200:
      msg = (f"Received error response {response.status_code} from model "
             f"server: {response.text}")
      raise ValueError(msg)
    outputs = response.json()["outputs"]
    return np.array(outputs)


def main(_):
  if FLAGS.count_images % FLAGS.serving_batch_size != 0:
    raise ValueError(f"The count_images ({FLAGS.count_images}) must be a "
                     "multiple of "
                     f"serving_batch_size ({FLAGS.serving_batch_size})")
  test_ds = mnist_lib.load_mnist(tfds.Split.TEST,
                                 batch_size=FLAGS.serving_batch_size)
  images_and_labels = tfds.as_numpy(test_ds.take(
      FLAGS.count_images // FLAGS.serving_batch_size))

  accurate_count = 0
  for batch_idx, (images, labels) in enumerate(images_and_labels):
    predictions_one_hot = serving_call_mnist(images)
    predictions_digit = np.argmax(predictions_one_hot, axis=1)
    labels_digit = np.argmax(labels, axis=1)
    accurate_count += np.sum(labels_digit == predictions_digit)
    running_accuracy = (
        100. * accurate_count / (1 + batch_idx) / FLAGS.serving_batch_size)
    logging.info(
        f" predicted digits = {predictions_digit} labels {labels_digit}. "
        f"Running accuracy {running_accuracy:.3f}%")


if __name__ == "__main__":
  app.run(main)