Comparing the impact that dynamic batching, static batching, and no batching have on throughput for a generative LLM inference server.
This repo uses gunicorn + flask to host a gpt-2-medium model implemented using code from Andrej Karpathy's nanoGPT repo. The dynamic batching algorithm is modeled after the Orca paper. See the Instructions section below for a full tutorial on how to run the experiments.
(Demo video: Dynamic-final.mov)
The server exposes an `/inference` endpoint that takes a request with a prompt and the number of completion tokens to generate. The server does not support terminating a generation early based on end tokens. A simple `/stats` endpoint also exists to display server stats for tracking experiment results.
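For example, a request might look like the following. This is a minimal sketch: the JSON field names, port, and response shape are assumptions, not taken from the server code.

```python
import requests

# Hypothetical client call; the field names ("prompt", "num_tokens"), the port
# 8500, and the response shape are assumptions, not read from app.py.
SERVER = "http://localhost:8500"

resp = requests.post(
    f"{SERVER}/inference",
    json={"prompt": "The meaning of life is", "num_tokens": 50},
)
print(resp.json())

# Fetch the server-side stats used to track experiment results.
print(requests.get(f"{SERVER}/stats").json())
```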
`app.py` is the main script that loads the model and defines the server logic.
`batching.py` contains two classes: `Inference` and `BatchingManager`.
`BatchingManager` defines how `Inference` objects are scheduled for model inference. When the server starts, an inference handler thread is launched that runs either `no_batching_loop`, `static_batching_loop`, or `dynamic_batching_loop`. These loops handle new inferences every 0.01 seconds.
New requests are enqueued using the `BatchingManager`'s `enqueue` function. Requests are transformed into `Inference` objects that hold the request data as well as metadata used by the `BatchingManager`. This `Inference` object is returned by the `enqueue` function. Each `Inference` object stores a reference to a unique `threading.Event` object that is used to signal when the inference has finished.
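A minimal sketch of this handshake, with simplified class and method bodies that only approximate the real `batching.py` implementation:

```python
import queue
import threading
import time


class Inference:
    """Holds one request plus the metadata the manager tracks (simplified)."""
    def __init__(self, prompt, num_tokens):
        self.prompt = prompt
        self.num_tokens = num_tokens
        self.result = None
        self.finished = threading.Event()  # set when generation completes


class BatchingManager:
    def __init__(self, model):
        self.model = model
        self.pending = queue.Queue()

    def enqueue(self, prompt, num_tokens):
        # Wrap the request and return it so the request handler can wait on
        # inference.finished and then read inference.result.
        inference = Inference(prompt, num_tokens)
        self.pending.put(inference)
        return inference

    def static_batching_loop(self):
        # Runs on the inference handler thread: every 0.01 s, drain whatever
        # requests have arrived and run them through the model as one batch.
        while True:
            time.sleep(0.01)
            batch = []
            while not self.pending.empty():
                batch.append(self.pending.get())
            if batch:
                outputs = self.model.generate_batch(batch)  # hypothetical model call
                for inference, output in zip(batch, outputs):
                    inference.result = output
                    inference.finished.set()
```

A request handler would then do roughly `inference = manager.enqueue(prompt, n)`, `inference.finished.wait()`, and return `inference.result`.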
Implementations for nobatch, static, and dynamic generation can be found in the `generate` folder.
These generation functions are centered around a `ServerModel` object defined in `model.py`. Dynamic batching requires a `DynamicBatchingServerModel` object, which extends `ServerModel` with modified attention and batch inference functions.
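Roughly, the relationship between the two classes looks like the skeleton below. The method names and bodies are illustrative only; the real implementations live in `model.py` and the `generate` folder.

```python
class ServerModel:
    """Wraps the GPT model for the no-batching and static-batching paths (sketch)."""

    def __init__(self, model):
        self.model = model

    def generate(self, prompts, num_tokens):
        # Standard batched decoding: every sequence in the batch starts together
        # and is stepped forward the same number of times.
        ...


class DynamicBatchingServerModel(ServerModel):
    """Extends ServerModel with what dynamic batching needs (sketch)."""

    def batch_inference(self, inferences):
        # Iteration-level scheduling in the spirit of Orca: finished sequences
        # leave the batch and newly arrived requests join between decoding steps,
        # which requires attention that tolerates sequences of different lengths.
        ...
```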
Client code can be found in the `client` folder.
`client.py` is a script that launches inference requests to the server, waits for the requests to finish, and then prints the results from the server's `/stats` endpoint.
`data.py` contains a `PromptData` class that is used by the client script. `PromptData` generates the inference data used for requests.
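As an illustration only (the real `PromptData` may differ), the client-side data generation could look roughly like this:

```python
import random


class PromptData:
    """Illustrative sketch: yields (prompt, num_tokens) pairs for requests."""

    def __init__(self, prompts, max_tokens=200, seed=0):
        self.prompts = prompts
        self.max_tokens = max_tokens
        self.rng = random.Random(seed)

    def sample(self):
        prompt = self.rng.choice(self.prompts)
        # Number of completion tokens to request for this inference.
        num_tokens = self.rng.randint(1, self.max_tokens)
        return prompt, num_tokens
```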
Google Compute Engine VM Setup
The GCE VM instance needs to be configured to host the flask server. If you only want to run the server locally, steps 2-4 are not needed.
1. Make sure the instance has a GPU.
2. In the firewalls section, make sure that HTTP traffic and HTTPS traffic are toggled on.
3. Create a network tag that allows ingress traffic on port 8500.
4. Add this network tag to the VM configuration.
More details on creating a VM configuration that can host a flask app can be found here.
After the VM is created, ssh into the instance and clone this repo.
Install Python Requirements
- Make and activate a Python virtual environment
- `cd` into the `effective-batching` repo root directory
- Run `pip install -r requirements.txt`
Launch the app
No batching: `./server.sh nobatch`
Static batching: `./server.sh static`
Dynamic batching: `./server.sh dynamic`
Python Requirements
Install the requirements in the `requirements.txt` file.
Environment Variables
- Make a `.env` file in the `effective-batching/client` folder
- Set `IP=<YOUR_GCE_INSTANCE_EXTERNAL_IP>` in the file.
The client script will read from the `IP` environment variable to format the request.
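For reference, reading the variable and formatting a request can look roughly like this. This is a sketch using `python-dotenv`; the JSON field names are assumptions, not taken from the actual client code.

```python
import os

import requests
from dotenv import load_dotenv  # requires the python-dotenv package

load_dotenv()  # loads .env from the working directory (the client folder)
IP = os.environ["IP"]  # external IP of the GCE instance

# Port 8500 matches the ingress rule from the VM setup section; the JSON field
# names here are illustrative only.
url = f"http://{IP}:8500/inference"
resp = requests.post(url, json={"prompt": "hi", "num_tokens": 10})
print(resp.json())
```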
Run the client script
To launch `numsamples` requests and display request stats, run:
`python3 client.py --numsamples 100`
To launch a single request, run:
`python3 client.py --prompt "hi" --numtokens 10`
(Demo video: Dynamic.mov)
Dynamic batching has the best throughput and latency-per-token.
(The following table displays server stats after the client script sends 100 requests at roughly 1 request per second, with the number of requested tokens per request randomly sampled from a normal distribution ranging from 1 to 200 tokens.)
Request latency vs. number of requested tokens: dynamic batching is noticeably more 'fair' than static and no batching at serving smaller requests faster.