Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable FusedSDPA for prompt attention with VLLM_PROMPT_USE_FUSEDSDPA #168

Merged
merged 14 commits into from
Aug 19, 2024
30 changes: 20 additions & 10 deletions vllm/attention/backends/habana_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
###############################################################################

import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type

Expand Down Expand Up @@ -158,6 +159,12 @@ def __init__(
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads

self.prefill_usefusedsdpa = os.getenv('VLLM_PREFILL_USE_FUSEDSDPA',
libinta marked this conversation as resolved.
Show resolved Hide resolved
'0') in ['1', 'true']
if self.prefill_usefusedsdpa:
assert alibi_slopes is None, \
'Prefill with FusedSDPA not supported with alibi slopes!'

suppored_head_sizes = HabanaPagedAttention.get_supported_head_sizes()
if head_size not in suppored_head_sizes:
raise ValueError(
Expand Down Expand Up @@ -211,15 +218,18 @@ def forward(
if attn_metadata.is_prompt:
# Prompt run.
if kv_cache is None or attn_metadata.block_tables.numel() == 0:
# TODO: move this outside of model
assert attn_metadata.attn_bias is not None, \
'attn_bias must be set before calling model.forward!'
attn_bias = attn_metadata.attn_bias
if self.alibi_slopes is not None and \
self.position_bias is not None:
attn_bias.add_(self.position_bias[:, :,
-attn_bias.size(2):,
-attn_bias.size(3):])
if not self.prefill_usefusedsdpa:
# TODO: move this outside of model
assert attn_metadata.attn_bias is not None, \
'attn_bias must be set before calling model.forward!'
attn_bias = attn_metadata.attn_bias
if self.alibi_slopes is not None and \
self.position_bias is not None:
attn_bias.add_(self.position_bias[:, :,
-attn_bias.size(2):,
-attn_bias.size(3):])
else:
attn_bias = None

query_shape = (batch_size, seq_len, self.num_heads,
self.head_size)
Expand All @@ -232,7 +242,7 @@ def forward(
attn_bias=attn_bias,
p=0.0,
scale=self.scale,
)
valid_seq_lengths=attn_metadata.seq_lens_tensor)
output = out.reshape(batch_size, seq_len, hidden_size)
else:
# prefix-enabled attention
Expand Down
59 changes: 46 additions & 13 deletions vllm/hpu/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@
import torch
import torch.nn.functional as F

try:
from habana_frameworks.torch.hpex.kernels import FusedSDPA
except ImportError:
print("Not using HPU fused scaled dot-product attention kernel.")
FusedSDPA = None

import vllm.hpu.utils as hpu_utils

PA_SPLIT_VALUE = (os.environ.get('PA_SPLIT_VALUE', '1') == '1')
Expand Down Expand Up @@ -123,6 +129,21 @@ def static_fused_moe(hidden_states, w1, w2, score, topk):
return final_hidden_states.view(-1, D)


#TODO: remove after SW-195415 fix
libinta marked this conversation as resolved.
Show resolved Hide resolved
def repeat_kv(kv: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to
(batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = kv.shape
if n_rep == 1:
return kv
kv = kv[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen,
head_dim)
return kv.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


@hpu_utils.with_mark_steps
def prompt_attention(
query: torch.Tensor,
Expand All @@ -131,23 +152,35 @@ def prompt_attention(
attn_bias: Optional[torch.Tensor] = None,
p: float = 0.0,
scale: Optional[float] = None,
valid_seq_lengths: Optional[torch.Tensor] = None,
) -> torch.Tensor:
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
query_heads = query.size(1)
kv_heads = key.size(1)
if query_heads != kv_heads:
query = query.unflatten(1, (kv_heads, -1))
key = key.unflatten(1, (kv_heads, 1))
value = value.unflatten(1, (kv_heads, 1))
attn_bias = attn_bias.unsqueeze(2)
attn_weights = torch.matmul(query * scale, key.transpose(-1, -2))
if attn_bias is not None:
attn_weights.add_(attn_bias)
attn_weights = torch.softmax(attn_weights, dim=-1)
attn_weights = torch.matmul(attn_weights, value)
if query_heads != kv_heads:
attn_weights = attn_weights.flatten(1, 2)
attn_weights = attn_weights.transpose(1, 2)
if attn_bias is not None or FusedSDPA is None:
if query_heads != kv_heads:
query = query.unflatten(1, (kv_heads, -1))
key = key.unflatten(1, (kv_heads, 1))
value = value.unflatten(1, (kv_heads, 1))
attn_bias = attn_bias.unsqueeze(2)
attn_weights = torch.matmul(query * scale, key.transpose(-1, -2))
if attn_bias is not None:
attn_weights.add_(attn_bias)
attn_weights = torch.softmax(attn_weights, dim=-1)
attn_weights = torch.matmul(attn_weights, value)
if query_heads != kv_heads:
attn_weights = attn_weights.flatten(1, 2)
attn_weights = attn_weights.transpose(1, 2)
else:
#TODO: remove after SW-195415 fix
libinta marked this conversation as resolved.
Show resolved Hide resolved
if query_heads != kv_heads:
key = repeat_kv(key, int(query_heads // kv_heads))
value = repeat_kv(value, int(query_heads // kv_heads))
softmax_mode = 'fast'
recompute_mode = True
attn_weights = FusedSDPA.apply(query, key, value, None, 0.0, True,
scale, softmax_mode, recompute_mode,
valid_seq_lengths, 'left')
return attn_weights
5 changes: 2 additions & 3 deletions vllm/model_executor/layers/layernorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,8 @@
logger = init_logger(__name__)
if is_hpu():
try:
from habana_frameworks.torch.hpex.normalization import (FusedRMSNorm as
HPUFusedRMSNorm
)
from habana_frameworks.torch.hpex.normalization import (
FusedRMSNorm as HPUFusedRMSNorm)
except ImportError:
logger.warning(
"Could not import HPU FusedRMSNorm kernel. "
Expand Down
5 changes: 4 additions & 1 deletion vllm/worker/habana_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,9 @@ class HpuModelAdapter():

def __init__(self, model, enforce_eager):
self.model = model
self.prefill_use_fusedsdpa = os.getenv('VLLM_PREFILL_USE_FUSEDSDPA',
'0') in ['1', 'true']

if not htorch.utils.internal.is_lazy() and not enforce_eager:
self.model = torch.compile(self.model,
backend='hpu_backend',
Expand All @@ -159,7 +162,7 @@ def __init__(self, model, enforce_eager):
def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device,
dtype):
prefill_metadata = attn_metadata
if prefill_metadata is None:
if prefill_metadata is None or self.prefill_use_fusedsdpa:
return attn_metadata

seq_lens_t = prefill_metadata.seq_lens_tensor
Expand Down