[jax2tf] Added instructions for using OSS TensorFlow model server.
gnecula committed Mar 17, 2021
1 parent 25704a0 commit dacc28c
Showing 4 changed files with 259 additions and 6 deletions.
8 changes: 4 additions & 4 deletions jax/experimental/jax2tf/examples/mnist_lib.py
@@ -178,8 +178,8 @@ def train(train_ds, test_ds, num_epochs, with_classifier=True):
test_acc = PureJaxMNIST.accuracy(PureJaxMNIST.predict, params, test_ds)
logging.info(
f"{PureJaxMNIST.name}: Epoch {epoch} in {epoch_time:0.2f} sec")
logging.info(f"{PureJaxMNIST.name}: Training set accuracy {train_acc}")
logging.info(f"{PureJaxMNIST.name}: Test set accuracy {test_acc}")
logging.info(f"{PureJaxMNIST.name}: Training set accuracy {100. * train_acc:0.2f}%")
logging.info(f"{PureJaxMNIST.name}: Test set accuracy {100. * test_acc:0.2f}%")

return (functools.partial(
PureJaxMNIST.predict, with_classifier=with_classifier), params)
@@ -269,8 +269,8 @@ def train(train_ds, test_ds, num_epochs, with_classifier=True):
test_acc = PureJaxMNIST.accuracy(FlaxMNIST.predict, optimizer.target,
test_ds)
logging.info(f"{FlaxMNIST.name}: Epoch {epoch} in {epoch_time:0.2f} sec")
logging.info(f"{FlaxMNIST.name}: Training set accuracy {train_acc}")
logging.info(f"{FlaxMNIST.name}: Test set accuracy {test_acc}")
logging.info(f"{FlaxMNIST.name}: Training set accuracy {100. * train_acc:0.2f}%")
logging.info(f"{FlaxMNIST.name}: Test set accuracy {100. * test_acc:0.2f}%")

# See discussion in README.md for packaging Flax models for conversion
predict_fn = functools.partial(FlaxMNIST.predict,
4 changes: 2 additions & 2 deletions jax/experimental/jax2tf/examples/saved_model_main.py
@@ -36,8 +36,8 @@
import tensorflow as tf # type: ignore
import tensorflow_datasets as tfds # type: ignore

flags.DEFINE_string("model", "mnist_flax",
"Which model to use: mnist_flax, mnist_pure_jax.")
flags.DEFINE_enum("model", "mnist_flax", ["mnist_flax", "mnist_pure_jax"],
"Which model to use.")
flags.DEFINE_boolean("model_classifier_layer", True,
("The model should include the classifier layer, or just "
"the last layer of logits. Set this to False when you "
126 changes: 126 additions & 0 deletions jax/experimental/jax2tf/examples/serving/README.md
@@ -0,0 +1,126 @@
Using jax2tf with TensorFlow Serving
====================================

This is a supplement to the
[examples/README.md](https://g3doc.corp.google.com/third_party/py/jax/experimental/jax2tf/examples/README.md)
with example code and
instructions for using `jax2tf` with the OSS TensorFlow model server.
For Google-internal versions of model server, see the `internal` subdirectory.

The goal of `jax2tf` is to convert JAX functions
into Python functions that behave as if they had been written with TensorFlow.
These functions can be traced and saved in a SavedModel using **standard TensorFlow
code, so the user has full control over what metadata is saved in the
SavedModel**.
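
As a rough sketch of the pattern (this is not the exact code in `saved_model_main.py`; `predict_fn` below is a stand-in for a JAX prediction function with its parameters already bound):

```python
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

def predict_fn(images):
  # Stand-in for a real JAX model; returns one-hot logits for a batch of
  # MNIST images of shape [B, 28, 28, 1].
  return jnp.zeros((images.shape[0], 10))

# Convert the JAX function to one that uses TensorFlow ops, then wrap it in a
# tf.function so it can be traced with a fixed input signature.
tf_predict = tf.function(
    jax2tf.convert(predict_fn),
    input_signature=[tf.TensorSpec([1, 28, 28, 1], tf.float32)],
    autograph=False)

# Save with the standard TensorFlow SavedModel APIs.
module = tf.Module()
module.predict = tf_predict
tf.saved_model.save(module, "/tmp/jax2tf/sketch/1",
                    signatures=tf_predict.get_concrete_function())
```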

The only difference in the SavedModel produced with jax2tf is that the
function graphs may contain
[XLA TF ops](http://g3doc/third_party/py/jax/experimental/jax2tf/README.md#caveats)
that require enabling XLA for execution in the model server. This
is achieved with a model-server command-line flag (see Step 3 below). There are no other
differences compared to a SavedModel produced directly by TensorFlow.

This serving example uses
[saved_model_main.py](http://google3/third_party/py/jax/experimental/jax2tf/examples/saved_model_main.py)
for saving the SavedModel and adds code specific to interacting with the
model server:
[model_server_request.py](http://google3/third_party/py/jax/experimental/jax2tf/examples/serving/model_server_request.py).

0. *Set up JAX and TensorFlow Serving*.

If you have already installed JAX and TensorFlow Serving, you can skip most of these steps, but do set the
environment variables `JAX2TF_EXAMPLES` and `DOCKER_IMAGE`.

The following clones a local copy of the JAX sources and installs the `jax`, `jaxlib`, and `flax` packages.
We also install TensorFlow, which is needed for the `jax2tf` feature and for the rest of this example;
we use the `tf_nightly` package to get an up-to-date version.

```shell
git clone https://github.com/google/jax
JAX2TF_EXAMPLES=$(pwd)/jax/jax/experimental/jax2tf/examples
pip install -e jax
pip install flax jaxlib tensorflow_datasets tensorflow_serving_api tf_nightly
```
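
As a quick sanity check, you can verify that `jax2tf` and TensorFlow import correctly:

```shell
python -c "import tensorflow as tf; from jax.experimental import jax2tf; print(tf.__version__)"
```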

We then install [TensorFlow Serving in a Docker image](https://www.tensorflow.org/tfx/serving/docker),
again using a recent "nightly" version. Install Docker and then run:

```shell
DOCKER_IMAGE=tensorflow/serving:nightly
docker pull ${DOCKER_IMAGE}
```

1. *Set some variables*.

```shell
# Shortcuts
# Where to save SavedModels
MODEL_PATH=/tmp/jax2tf/saved_models
# The example model. The options are "mnist_flax" and "mnist_pure_jax"
MODEL=mnist_flax
# The batch size for the SavedModel.
SERVING_BATCH_SIZE=1
# Increment this when you make changes to the model parameters after the
# initial model generation (Step 2 below).
MODEL_VERSION=$(( 1 + ${MODEL_VERSION:-0} ))
```

2. *Train and export a model.* You will use `saved_model_main.py` to train
and export a SavedModel (see more details in README.md). Execute:

```shell
python ${JAX2TF_EXAMPLES}/saved_model_main.py --model=${MODEL} \
--model_path=${MODEL_PATH} --model_version=${MODEL_VERSION} \
--serving_batch_size=${SERVING_BATCH_SIZE} \
--compile_model \
--noshow_model
```

The SavedModel will be in `${MODEL_PATH}/${MODEL}/${MODEL_VERSION}`.
You can inspect it with `saved_model_cli`:

```shell
saved_model_cli show --all --dir ${MODEL_PATH}/${MODEL}/${MODEL_VERSION}
```
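
To look at just the serving signature (assuming the default signature key `serving_default`), a narrower invocation should also work:

```shell
saved_model_cli show --dir ${MODEL_PATH}/${MODEL}/${MODEL_VERSION} \
    --tag_set serve --signature_def serving_default
```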

3. *Start a local model server* with XLA compilation enabled. Execute:

```shell
docker run -p 8500:8500 -p 8501:8501 \
--mount type=bind,source=${MODEL_PATH}/${MODEL}/,target=/models/${MODEL} \
-e MODEL_NAME=${MODEL} -t --rm --name=serving ${DOCKER_IMAGE} \
--xla_cpu_compilation_enabled=true &
```

Note that we are forwarding the ports 8500 (for the gRPC server) and
8501 (for the HTTP REST server).

You do not need to redo this step if you change the model parameters and
regenerate the SavedModel, as long as you bump `${MODEL_VERSION}`.
The running model server will automatically load newer model versions.
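
To check which model versions the server has loaded, you can query the model status endpoint of the HTTP REST server (assuming port 8501 is forwarded as above):

```shell
curl http://localhost:8501/v1/models/${MODEL}
```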


4. *Send RPC requests.* Execute:

```shell
python ${JAX2TF_EXAMPLES}/serving/model_server_request.py --model_spec_name=${MODEL} \
--use_grpc --prediction_service_addr=localhost:8500 \
--serving_batch_size=${SERVING_BATCH_SIZE} \
--count_images=128
```

If you see an error `Input to reshape is a tensor with 12544 values, but the requested shape has 784`,
then the batch size of your request is 16 (= 12544 / 784) while the model loaded
in the model server was saved with batch size 1 (= 784 / 784). Check that
you use the same `--serving_batch_size` for generating the model and
for sending the requests.
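
To exercise the HTTP REST path instead of gRPC, a variant of the same command should work, pointing at the REST port forwarded in Step 3:

```shell
python ${JAX2TF_EXAMPLES}/serving/model_server_request.py --model_spec_name=${MODEL} \
    --nouse_grpc --prediction_service_addr=localhost:8501 \
    --serving_batch_size=${SERVING_BATCH_SIZE} \
    --count_images=128
```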

5. *Experiment with different models and batch sizes*.

- You can set `MODEL=mnist_pure_jax` to use a simpler model written in
pure JAX, then restart from Step 1.

- You can change the batch size at which the model is converted from JAX
and saved. Set `SERVING_BATCH_SIZE=16` and restart from Step 2.
In Step 4, you should pass a `--count_images`
parameter that is a multiple of the serving batch size you choose.
127 changes: 127 additions & 0 deletions jax/experimental/jax2tf/examples/serving/model_server_request.py
@@ -0,0 +1,127 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Demonstrates using jax2tf with TensorFlow model server.
See README.md for instructions.
"""
import grpc
import json
import logging
import requests

from absl import app
from absl import flags

from jax.experimental.jax2tf.examples import mnist_lib

import numpy as np
import tensorflow as tf # type: ignore
import tensorflow_datasets as tfds # type: ignore
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2_grpc


FLAGS = flags.FLAGS

flags.DEFINE_boolean(
    "use_grpc", True,
    "Use the gRPC API (default), or the HTTP REST API.")

flags.DEFINE_string(
    "model_spec_name", "",
    "The name you used to export your model to model server (e.g., mnist_flax).")

flags.DEFINE_string(
    "prediction_service_addr",
    "localhost:8500",
    "Endpoint for the prediction service. If you serve your model "
    "locally using TensorFlow model server, then you can use \"localhost:8500\" "
    "for the gRPC server and \"localhost:8501\" for the HTTP REST server.")

flags.DEFINE_integer("serving_batch_size", 1,
                     "Batch size for the serving request. Must match the "
                     "batch size at which the model was saved. Must divide "
                     "--count_images.",
                     lower_bound=1)
flags.DEFINE_integer("count_images", 16,
                     "How many images to test.",
                     lower_bound=1)


def serving_call_mnist(images):
  """Sends an RPC or REST request to the model server.

  Args:
    images: A numpy.ndarray of shape [B, 28, 28, 1] with the batch of images to
      perform inference on.

  Returns:
    A numpy.ndarray of shape [B, 10] with the one-hot inference response.
  """
  if FLAGS.use_grpc:
    channel = grpc.insecure_channel(FLAGS.prediction_service_addr)
    stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)

    request = predict_pb2.PredictRequest()
    request.model_spec.name = FLAGS.model_spec_name
    request.model_spec.signature_name = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
    # You can see the name of the input ("inputs") in the SavedModel dump.
    request.inputs["inputs"].CopyFrom(
        tf.make_tensor_proto(images, dtype=images.dtype, shape=images.shape))
    response = stub.Predict(request)
    # We could also use response.outputs["output_0"], where "output_0" is the
    # name of the output (which you can also see in the SavedModel dump).
    # Alternatively, we just take the first (and only) output.
    outputs, = response.outputs.values()
    return tf.make_ndarray(outputs)
  else:
    # Use the HTTP REST API.
    images_json = json.dumps(images.tolist())
    # You can see the name of the input ("inputs") in the SavedModel dump.
    data = f'{{"inputs": {images_json}}}'
    predict_url = f"http://{FLAGS.prediction_service_addr}/v1/models/{FLAGS.model_spec_name}:predict"
    response = requests.post(predict_url, data=data)
    if response.status_code != 200:
      msg = (f"Received error response {response.status_code} from model "
             f"server: {response.text}")
      raise ValueError(msg)
    outputs = response.json()["outputs"]
    return np.array(outputs)


def main(_):
  if FLAGS.count_images % FLAGS.serving_batch_size != 0:
    raise ValueError(f"The count_images ({FLAGS.count_images}) must be a "
                     "multiple of "
                     f"serving_batch_size ({FLAGS.serving_batch_size})")
  test_ds = mnist_lib.load_mnist(tfds.Split.TEST,
                                 batch_size=FLAGS.serving_batch_size)
  images_and_labels = tfds.as_numpy(test_ds.take(
      FLAGS.count_images // FLAGS.serving_batch_size))

  accurate_count = 0
  for batch_idx, (images, labels) in enumerate(images_and_labels):
    predictions_one_hot = serving_call_mnist(images)
    predictions_digit = np.argmax(predictions_one_hot, axis=1)
    labels_digit = np.argmax(labels, axis=1)
    accurate_count += np.sum(labels_digit == predictions_digit)
    running_accuracy = (
        100. * accurate_count / (1 + batch_idx) / FLAGS.serving_batch_size)
    logging.info(
        f" predicted digits = {predictions_digit} labels {labels_digit}. "
        f"Running accuracy {running_accuracy:.3f}%")


if __name__ == "__main__":
  app.run(main)
