pip install -e . && pip install flash_attn==2.5.8 --no-build-isolation
from transformers import LlamaForCausalLM

from criti_prefill.modeling_patch import replace_llama_eattention, criti_config

# Apply the patched attention implementation before loading the model.
replace_llama_eattention()

model = LlamaForCausalLM.from_pretrained(
    args.model_name_or_path,
    device_map=device,
    torch_dtype=dtype,
    attn_implementation="flash_attention_2",
)

# Configure CritiPrefill on the loaded model.
criti_config(
    model,
    segment_size=args.segment_size,
    threshold_len=args.threshold_len,
    block_size=args.block_size,
    budgets=args.budgets,
    layer_fusion=args.layer_fusion,
    layer_skip=args.layer_skip,
)
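The snippet above reads its settings from an `args` namespace that is not defined here; one common way to supply it is a small `argparse` parser. The sketch below uses the same flag names as the attributes read above, but the types and default values are illustrative assumptions, not the repository's tuned settings:

```python
import argparse

# Hypothetical CLI wiring for the usage snippet; flag names mirror the
# attributes the snippet reads from `args`. Types and defaults are
# illustrative assumptions only.
parser = argparse.ArgumentParser()
parser.add_argument("--model_name_or_path", type=str,
                    default="meta-llama/Llama-2-7b-hf")  # assumed default
parser.add_argument("--segment_size", type=int, default=512)
parser.add_argument("--threshold_len", type=int, default=4096)
parser.add_argument("--block_size", type=int, default=32)
parser.add_argument("--budgets", type=int, default=1024)
parser.add_argument("--layer_fusion", action="store_true")
parser.add_argument("--layer_skip", type=int, default=0)

# Parse with no CLI input to fall back on the defaults; in a real script
# call parser.parse_args() so command-line flags take effect.
args = parser.parse_args([])
```

Any of these values can then be overridden from the command line, e.g. `--segment_size 256 --layer_fusion`.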
Time to first token (TTFT) is one of the most intuitive metrics of user experience, yet it tends to be significantly longer than the per-token decoding time. CritiPrefill can significantly reduce TTFT while maintaining generation quality.