
Optimize batch inference for text generation #586

Merged
merged 1 commit into deepjavalibrary:master on Mar 30, 2023

Conversation

siddvenk
Contributor

Description

TL;DR: Implements actual batch inference for text generation use cases.

In our handlers we currently use the transformers pipeline interface to handle tokenization and generation. This abstraction is slow for batched requests: when a batch of inputs is passed in, the inputs are executed sequentially (the run_multi method is just a for loop over single forward passes; see https://github.com/huggingface/transformers/blob/main/src/transformers/pipelines/base.py#L1085-L1112).

This PR changes the generation implementation, for text generation tasks only, to use the tokenizer's encode/decode and the model's generate methods directly. Doing so achieves true batch processing. The change is compatible with both Accelerate and DeepSpeed.
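
For illustration only, here is a minimal sketch of the batched-generation pattern described above. It is not the handler code from this PR; the model id, prompts, and generation settings are placeholders:

```python
# Sketch of batched text generation without the pipeline abstraction:
# tokenize the whole batch at once, run a single model.generate call,
# and decode the results together.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "gpt2"  # placeholder; the PR tested bloom3b, gpt2, and opt2.7b
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
tokenizer.pad_token = tokenizer.eos_token  # causal LMs often lack a pad token
model = AutoModelForCausalLM.from_pretrained(model_id)
model.eval()

prompts = ["Deep learning is", "The weather today is", "DJL Serving can"]

# One generate call for the whole batch instead of a per-input loop.
inputs = tokenizer(prompts, return_tensors="pt", padding=True)
with torch.no_grad():
    output_ids = model.generate(
        **inputs,
        max_new_tokens=50,
        pad_token_id=tokenizer.pad_token_id,
    )
texts = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
```

The key point is that tokenizing the batch with padding and issuing a single model.generate call amortizes the forward passes across the whole batch, instead of looping over inputs one at a time as the pipeline does.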

Follow-up:

  • I need to see whether we can apply this same implementation to other tasks. In particular, the pipeline abstraction for other tasks comes with preprocess and postprocess methods that run before and after tokenization, so that needs a bit more exploration.

Tests:
For the following tests, I executed requests with batch sizes 1, 2, 4, and 8 using both the pipeline implementation and the no-pipeline implementation. The model used is bigscience/bloom3b. I have also tested with gpt2 and opt2.7b.

Inference time in seconds:

| Batch size | HuggingFace Accelerate, with pipeline | HuggingFace Accelerate, no pipeline | DeepSpeed, with pipeline | DeepSpeed, no pipeline |
| --- | --- | --- | --- | --- |
| 1 | 25.356090545654297 | 14.039892673492432 | 8.60791540145874 | 9.049405574798584 |
| 2 | 50.49015522003174 | 13.926620721817017 | 16.066924810409546 | 8.418354988098145 |
| 4 | 99.71887421607971 | 13.960245132446289 | 32.61290001869202 | 8.980300903320312 |
| 8 | 197.7882523536682 | 13.707027196884155 | 65.23944902420044 | 10.209055423736572 |
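
For reference, a hypothetical timing harness in the spirit of the comparison above could look like the sketch below. This is not the script that produced the numbers in the table; the model id, prompt, and generation settings are stand-ins:

```python
# Compare the pipeline path (inputs processed one at a time) against a single
# batched tokenize + generate call, across several batch sizes.
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

model_id = "gpt2"  # stand-in; the PR used bloom3b, gpt2, and opt2.7b
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_id)
generator = pipeline("text-generation", model=model, tokenizer=tokenizer)

prompt = "Large language models are"
for batch_size in (1, 2, 4, 8):
    batch = [prompt] * batch_size

    start = time.time()
    generator(batch, max_new_tokens=50)  # pipeline: inputs run sequentially
    pipeline_time = time.time() - start

    start = time.time()
    inputs = tokenizer(batch, return_tensors="pt", padding=True)
    with torch.no_grad():
        model.generate(
            **inputs,
            max_new_tokens=50,
            pad_token_id=tokenizer.pad_token_id,
        )  # direct: one batched generate call
    direct_time = time.time() - start

    print(f"batch {batch_size}: pipeline {pipeline_time:.2f}s, direct {direct_time:.2f}s")
```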

@siddvenk requested review from zachgk, frankfliu and a team as code owners on March 29, 2023 23:54
@siddvenk merged commit 77bf5ad into deepjavalibrary:master on Mar 30, 2023
@siddvenk deleted the python-handler-pipeline branch on March 30, 2023 00:05