


FlashInfer

Kernel Library for LLM Serving

| Blog | Documentation | Slack | Discussion Forum |


FlashInfer is a library for Large Language Models that provides high-performance implementations of LLM GPU kernels such as FlashAttention, SparseAttention, PageAttention, Sampling, and more. FlashInfer focuses on LLM serving and inference, and delivers state-of-the-art performance across diverse scenarios.

The unique features of FlashInfer include:

  1. Comprehensive Attention Kernels: Attention kernels that cover all the common use cases of LLM serving, including single-request and batched versions of Prefill, Decode, and Append kernels, over different KV-Cache formats (Padded Tensor, Ragged Tensor, and Page Table).
  2. Optimized Shared-Prefix Batch Decoding: FlashInfer improves shared-prefix batch decoding performance through cascading, yielding up to a 31x speedup over the baseline vLLM PageAttention implementation (for a long prompt of 32768 tokens and a large batch size of 256).
  3. Accelerated Attention for Compressed/Quantized KV-Cache: Modern LLMs are often deployed with a quantized/compressed KV-Cache to reduce memory traffic. FlashInfer accelerates these scenarios by optimizing Grouped-Query Attention, Fused-RoPE Attention, and Quantized Attention.
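Cascading relies on the fact that softmax attention over a KV sequence can be split into chunks and recombined exactly: attention over the shared prefix is computed once, attention over each request's own suffix is computed separately, and the two partial outputs are merged using their log-sum-exp (LSE) values. The pure-Python sketch below illustrates only that merge rule on toy vectors. The helpers `partial_attention` and `merge_states` are hypothetical names, not FlashInfer's API, and the sketch uses a natural-log LSE for readability.

```python
import math
from typing import List, Tuple

Vector = List[float]


def partial_attention(q: Vector, keys: List[Vector], values: List[Vector]) -> Tuple[Vector, float]:
    """Softmax attention of one query over a chunk of KV entries (hypothetical helper).

    Returns the chunk's normalized output and the natural-log log-sum-exp of its
    scaled scores, which is all that is needed to merge chunks afterwards.
    """
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    out = [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(len(values[0]))]
    return out, m + math.log(total)


def merge_states(o_a: Vector, lse_a: float, o_b: Vector, lse_b: float) -> Tuple[Vector, float]:
    """Combine two partial attention results computed over disjoint KV chunks."""
    m = max(lse_a, lse_b)
    w_a, w_b = math.exp(lse_a - m), math.exp(lse_b - m)  # relative softmax mass of each chunk
    out = [(w_a * x + w_b * y) / (w_a + w_b) for x, y in zip(o_a, o_b)]
    return out, m + math.log(w_a + w_b)


# Toy data: a shared prefix (computed once for all requests) and one request's suffix.
q = [0.5, -1.0, 0.25, 2.0]
prefix_keys = [[1.0, 0.0, 0.5, -0.5], [0.0, 1.0, -1.0, 0.5], [0.3, 0.3, 0.3, 0.3]]
prefix_vals = [[1.0, 2.0], [3.0, -1.0], [0.0, 0.5]]
suffix_keys = [[-0.2, 0.4, 1.0, 0.0], [0.5, -0.5, 0.0, 1.0]]
suffix_vals = [[2.0, 2.0], [-1.0, 4.0]]

o_prefix, lse_prefix = partial_attention(q, prefix_keys, prefix_vals)
o_suffix, lse_suffix = partial_attention(q, suffix_keys, suffix_vals)
o_merged, lse_merged = merge_states(o_prefix, lse_prefix, o_suffix, lse_suffix)

# Attending over the concatenated sequence in a single pass gives the same result.
o_full, lse_full = partial_attention(q, prefix_keys + suffix_keys, prefix_vals + suffix_vals)
print(o_merged, o_full)
```

Because the prefix result does not depend on the request, a batch of requests sharing that prefix can reuse `o_prefix` and `lse_prefix` and only pay for their suffixes.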

FlashInfer supports PyTorch, TVM, and C++ (header-only) APIs, and can be easily integrated into existing projects.

News

  • [Sept 2024] We've launched a Slack workspace for FlashInfer users and developers. Join us for timely support, discussions, updates, and knowledge sharing!
  • [Jan 31, 2024] Blog post: Cascade Inference: Memory-Efficient Shared Prefix Batch Decoding
  • [Jan 31, 2024] Blog post: Accelerating Self-Attentions for LLM Serving with FlashInfer

Getting Started

Using our PyTorch API is the easiest way to get started:

Installation

We provide prebuilt wheels for Linux and you can try out FlashInfer with the following command:

# For CUDA 12.4 & torch 2.4
pip install flashinfer -i https://flashinfer.ai/whl/cu124/torch2.4
# For other CUDA & torch versions, please check https://docs.flashinfer.ai/installation.html

Alternatively, you can build from source:

git clone https://github.com/flashinfer-ai/flashinfer.git --recursive
cd flashinfer/python
pip install -e .

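To reduce binary size during build and testing, restrict compilation to your GPU's compute capability with TORCH_CUDA_ARCH_LIST (8.0 below targets compute capability 8.0, e.g. A100):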

git clone https://github.com/flashinfer-ai/flashinfer.git --recursive
cd flashinfer/python
# ref https://pytorch.org/docs/stable/generated/torch.cuda.get_device_capability.html#torch.cuda.get_device_capability
export TORCH_CUDA_ARCH_LIST=8.0
pip install -e .
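Rather than hard-coding 8.0, the value can be derived from the installed GPUs with `torch.cuda.get_device_capability()`, the function referenced in the comment above. The sketch below assumes PyTorch with CUDA is installed and that the build accepts a semicolon-separated TORCH_CUDA_ARCH_LIST, as PyTorch's C++ extension builder does; `arch_list_from_capabilities` is a hypothetical helper, not part of FlashInfer.

```python
from typing import Iterable, Tuple


def arch_list_from_capabilities(capabilities: Iterable[Tuple[int, int]]) -> str:
    """Format (major, minor) compute capabilities as a TORCH_CUDA_ARCH_LIST value.

    Duplicates (e.g. several identical GPUs) are removed and entries are sorted,
    then joined with ';' (assumed separator, as accepted by PyTorch's extension builder).
    """
    unique = sorted(set(capabilities))
    return ";".join(f"{major}.{minor}" for major, minor in unique)


if __name__ == "__main__":
    try:
        import torch
    except ImportError:
        torch = None

    if torch is not None and torch.cuda.is_available():
        # get_device_capability(i) returns a (major, minor) tuple for device i.
        caps = [torch.cuda.get_device_capability(i) for i in range(torch.cuda.device_count())]
        print(f'export TORCH_CUDA_ARCH_LIST="{arch_list_from_capabilities(caps)}"')
    else:
        print("PyTorch with CUDA not available; set TORCH_CUDA_ARCH_LIST by hand.")
```

Run it before `pip install -e .` and paste the printed `export` line into your shell; on a single A100 it would produce the same 8.0 used above.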

Trying it out

Below is a minimal example of using FlashInfer's single-request decode/append/prefill attention kernels:

import torch
import flashinfer

kv_len = 2048
num_kv_heads = 32
head_dim = 128

k = torch.randn(kv_len, num_kv_heads, head_dim).half().to(0)
v = torch.randn(kv_len, num_kv_heads, head_dim).half().to(0)

# decode attention

num_qo_heads = 32
q = torch.randn(num_qo_heads, head_dim).half().to(0)

o = flashinfer.single_decode_with_kv_cache(q, k, v) # decode attention without RoPE on-the-fly
o_rope_on_the_fly = flashinfer.single_decode_with_kv_cache(q, k, v, pos_encoding_mode="ROPE_LLAMA") # decode with LLaMA style RoPE on-the-fly

# append attention
append_qo_len = 128
q = torch.randn(append_qo_len, num_qo_heads, head_dim).half().to(0) # append attention, the last 128 tokens in the KV-Cache are the new tokens
o = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=True) # append attention without RoPE on-the-fly, apply causal mask
o_rope_on_the_fly = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=True, pos_encoding_mode="ROPE_LLAMA") # append attention with LLaMA style RoPE on-the-fly, apply causal mask

# prefill attention
qo_len = 2048
q = torch.randn(qo_len, num_qo_heads, head_dim).half().to(0) # prefill attention
o = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=False) # prefill attention without RoPE on-the-fly, do not apply causal mask
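To make the semantics of these calls concrete, here is a slow pure-Python reference of what we assume the kernels compute (no RoPE): scaled dot-product attention with a softmax scale of 1/sqrt(head_dim); grouped-query attention when num_qo_heads is a multiple of num_kv_heads, assuming the standard grouping where query head h reads KV head h // (num_qo_heads // num_kv_heads); and, with causal=True, a mask aligned to the end of the KV-Cache so the qo_len query tokens are treated as its last qo_len tokens, as in the append comment above. `reference_attention` is a hypothetical helper meant for checking shapes and masking on tiny inputs, not a description of FlashInfer's implementation.

```python
import math
from typing import List

Tensor3 = List[List[List[float]]]  # nested lists indexed [token][head][dim]


def reference_attention(q: Tensor3, k: Tensor3, v: Tensor3, causal: bool = False) -> Tensor3:
    """Naive single-request attention without RoPE (hypothetical reference).

    q: [qo_len][num_qo_heads][head_dim]; k, v: [kv_len][num_kv_heads][head_dim].
    Returns [qo_len][num_qo_heads][head_dim].
    """
    qo_len, num_qo_heads, head_dim = len(q), len(q[0]), len(q[0][0])
    kv_len, num_kv_heads = len(k), len(k[0])
    if num_qo_heads % num_kv_heads != 0:
        raise ValueError("num_qo_heads must be a multiple of num_kv_heads")
    if causal and qo_len > kv_len:
        raise ValueError("causal attention here assumes qo_len <= kv_len")
    group_size = num_qo_heads // num_kv_heads
    sm_scale = 1.0 / math.sqrt(head_dim)

    out = []
    for i in range(qo_len):
        # Query i is treated as token kv_len - qo_len + i; with a causal mask it sees
        # KV positions 0..that index (mask aligned to the bottom-right).
        last = kv_len - qo_len + i if causal else kv_len - 1
        row = []
        for h in range(num_qo_heads):
            kv_h = h // group_size  # GQA: consecutive query heads share one KV head
            scores = [sm_scale * sum(a * b for a, b in zip(q[i][h], k[j][kv_h]))
                      for j in range(last + 1)]
            m = max(scores)  # subtract the max for numerical stability
            weights = [math.exp(s - m) for s in scores]
            total = sum(weights)
            row.append([sum(w * v[j][kv_h][d] for j, w in enumerate(weights)) / total
                        for d in range(head_dim)])
        out.append(row)
    return out


# Tiny example: kv_len = 3, one KV head, two query heads (group size 2), head_dim = 2.
k = [[[1.0, 0.0]], [[0.0, 1.0]], [[1.0, 1.0]]]
v = [[[1.0, 0.0]], [[0.0, 1.0]], [[2.0, 2.0]]]
q = [[[0.5, 0.5], [1.0, -1.0]], [[0.0, 2.0], [0.3, 0.3]]]  # append 2 new tokens

o_full = reference_attention(q, k, v, causal=False)
o_causal = reference_attention(q, k, v, causal=True)
o_decode = reference_attention([q[-1]], k, v)[0]  # decode: a single query token
```

Decode is the qo_len = 1 special case, so a decode query of shape [num_qo_heads][head_dim] is wrapped as a length-1 sequence, as `o_decode` shows.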

Check out the documentation for usage of the batch decode/append/prefill kernels and the shared-prefix cascading kernels.
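The batch kernels work on a paged KV-Cache described by a CSR-style page table. As a sketch, assuming the `indptr` / `indices` / `last_page_len` layout used in the documentation's paged KV-Cache examples, the hypothetical helper `build_page_table` below flattens per-request page lists into those three arrays, and `kv_len_of` recovers each request's KV length as (num_pages - 1) * page_size + last_page_len. Check the documentation for the exact wrapper signatures before passing such arrays to FlashInfer.

```python
from typing import List, Tuple


def build_page_table(pages_per_request: List[List[int]], kv_lens: List[int],
                     page_size: int) -> Tuple[List[int], List[int], List[int]]:
    """Flatten per-request page lists into CSR-style (indptr, indices, last_page_len).

    Request r owns pages indices[indptr[r]:indptr[r + 1]]; every page is full except
    the last, which holds last_page_len[r] tokens (between 1 and page_size).
    """
    indptr, indices, last_page_len = [0], [], []
    for pages, kv_len in zip(pages_per_request, kv_lens):
        num_pages = (kv_len + page_size - 1) // page_size  # ceil(kv_len / page_size)
        if kv_len <= 0 or len(pages) != num_pages:
            raise ValueError(f"request with kv_len={kv_len} needs {num_pages} pages, got {len(pages)}")
        indices.extend(pages)
        indptr.append(len(indices))
        last_page_len.append(kv_len - (num_pages - 1) * page_size)
    return indptr, indices, last_page_len


def kv_len_of(r: int, indptr: List[int], last_page_len: List[int], page_size: int) -> int:
    """Recover request r's KV length from the page table."""
    num_pages = indptr[r + 1] - indptr[r]
    return (num_pages - 1) * page_size + last_page_len[r]


# Three requests with 20, 16 and 5 cached tokens and page_size = 16. Page ids are
# slots in a shared page pool and need not be contiguous.
page_size = 16
indptr, indices, last_page_len = build_page_table([[3, 7], [0], [5]], [20, 16, 5], page_size)
print(indptr, indices, last_page_len)  # [0, 2, 3, 4] [3, 7, 0, 5] [4, 16, 5]
```

Storing only page ids per request lets requests grow one page at a time without copying existing KV entries, which is why the last page is the only partially filled one.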

Run Benchmarks

We profile FlashInfer kernel performance with nvbench. You can compile and run the benchmarks with the following commands, starting from the repository root:

mkdir build
cp cmake/config.cmake build # you can modify the config.cmake to enable/disable benchmarks and change CUDA architectures
cd build
cmake ..
make -j12

Run ./bench_{single/batch}_{prefill/decode} to benchmark performance (e.g., ./bench_single_prefill for single-request prefill attention); ./bench_{single/batch}_{prefill/decode} --help lists the available options.

C++ API and TVM Bindings

FlashInfer also provides a C++ API and TVM bindings; please refer to the documentation for more details.

Adoption

Currently FlashInfer is adopted by the following projects:

Acknowledgement

FlashInfer is inspired by the FlashAttention 1 & 2, vLLM, Stream-K, and CUTLASS projects.