[Bugfix] Avoid import AttentionMetadata explicitly in Mllama (vllm-project#10593)

Signed-off-by: Isotr0py <2037008807@qq.com>
Isotr0py authored and weilong.yu committed Dec 13, 2024
1 parent 679fffe commit d099d22
Showing 5 changed files with 21 additions and 11 deletions.
5 changes: 5 additions & 0 deletions vllm/attention/backends/blocksparse_attn.py
@@ -87,6 +87,11 @@ def __post_init__(self):

 class BlocksparseFlashAttentionBackend(AttentionBackend):

+    @staticmethod
+    def get_name() -> str:
+        # For attention layer compatibility
+        return "FLASH_ATTN"
+
     @staticmethod
     def get_impl_cls() -> Type["BlocksparseFlashAttentionImpl"]:
         return BlocksparseFlashAttentionImpl
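The new get_name() matters because the Attention layer (see vllm/attention/layer.py below) resolves the backend's name into a _Backend enum member via backend_name_to_enum, and Mllama now branches on that enum. A minimal, self-contained sketch of that name-to-enum resolution, assuming a plain lookup by member name (illustrative only, not vLLM's actual implementation):

import enum
from typing import Optional


class _Backend(enum.Enum):
    # Stand-in for vllm.attention.selector._Backend; only the members
    # referenced in this commit are listed.
    FLASH_ATTN = enum.auto()
    FLASH_ATTN_VLLM_V1 = enum.auto()
    XFORMERS = enum.auto()
    TORCH_SDPA = enum.auto()


def backend_name_to_enum(backend_name: str) -> Optional[_Backend]:
    # Resolve a backend's get_name() string to an enum member by name,
    # returning None for unknown names.
    return _Backend[backend_name] if backend_name in _Backend.__members__ else None


# BlocksparseFlashAttentionBackend.get_name() now returns "FLASH_ATTN", so it
# resolves to the same member as the regular flash-attention backend.
assert backend_name_to_enum("FLASH_ATTN") is _Backend.FLASH_ATTN
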
3 changes: 2 additions & 1 deletion vllm/attention/layer.py
@@ -6,7 +6,7 @@

 import vllm.envs as envs
 from vllm.attention import AttentionMetadata, AttentionType
-from vllm.attention.selector import get_attn_backend
+from vllm.attention.selector import backend_name_to_enum, get_attn_backend
 from vllm.config import CacheConfig
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.model_executor.layers.quantization.base_config import (
@@ -98,6 +98,7 @@ def __init__(
         self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
                              alibi_slopes, sliding_window, kv_cache_dtype,
                              blocksparse_params, logits_soft_cap)
+        self.backend = backend_name_to_enum(attn_backend.get_name())

         # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
         # torch.compile works by registering the attention as one giant
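Storing the resolved enum on the layer at construction time is what lets model code check self.attn.backend later without importing backend-specific metadata classes. A toy sketch of that pattern with stand-in names (hypothetical, not the real vllm.attention.layer.Attention):

from enum import Enum, auto


class _Backend(Enum):  # stand-in for vllm.attention.selector._Backend
    FLASH_ATTN = auto()
    XFORMERS = auto()


class _FlashAttnBackendStub:  # stand-in for the selected AttentionBackend class
    @staticmethod
    def get_name() -> str:
        return "FLASH_ATTN"


class AttentionSketch:
    def __init__(self, attn_backend=_FlashAttnBackendStub) -> None:
        # Equivalent in spirit to:
        #   self.backend = backend_name_to_enum(attn_backend.get_name())
        name = attn_backend.get_name()
        self.backend = _Backend[name] if name in _Backend.__members__ else None


attn = AttentionSketch()
assert attn.backend is _Backend.FLASH_ATTN
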
14 changes: 7 additions & 7 deletions vllm/model_executor/models/mllama.py
@@ -32,9 +32,8 @@

 import vllm.distributed.parallel_state as ps
 from vllm.attention import Attention, AttentionMetadata, AttentionType
-from vllm.attention.backends.flash_attn import FlashAttentionMetadata
-from vllm.attention.backends.xformers import XFormersMetadata
 from vllm.attention.ops.paged_attn import PagedAttention
+from vllm.attention.selector import _Backend
 from vllm.config import VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.inputs import (INPUT_REGISTRY, DummyData, EncoderDecoderInputs,
@@ -828,7 +827,8 @@ def _attention_with_mask(
     ) -> torch.Tensor:
         # Skip writing kv-cache for the initial profiling run.
         if len(kv_cache.shape) > 1:
-            if isinstance(attn_metadata, FlashAttentionMetadata):
+            if self.attn.backend in (_Backend.FLASH_ATTN,
+                                     _Backend.FLASH_ATTN_VLLM_V1):
                 cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
                 cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
                 torch.ops._C_cache_ops.reshape_and_cache_flash(
@@ -842,7 +842,7 @@ def _attention_with_mask(
                     1.0,
                     1.0,
                 )
-            elif isinstance(attn_metadata, XFormersMetadata):
+            elif self.attn.backend in (_Backend.XFORMERS, _Backend.TORCH_SDPA):
                 key_cache, value_cache = PagedAttention.split_kv_cache(
                     kv_cache, self.num_local_key_value_heads, self.head_dim)
                 cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
@@ -852,9 +852,9 @@ def _attention_with_mask(
                     attn_metadata.cross_slot_mapping, "auto", 1.0, 1.0)
             else:
                 raise ValueError(
-                    f"Unsupported AttentionMetadata {type(attn_metadata)} "
-                    f"class found. Expected the AttentionMetadata to "
-                    f"be either XFormersMetadata or FlashAttentionMetadata.")
+                    f"Unsupported Attention backend {self.attn.backend} "
+                    "enum found. Expected the Attention backend to be "
+                    "FLASH_ATTN, FLASH_ATTN_VLLM_V1, XFORMERS or TORCH_SDPA.")

         # We have to call torch.sdpa for prefill when using a
         # custom cross-attention mask. Because the mask is not a
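The net effect in Mllama is that the KV-cache write path is selected by the enum stored on the attention layer rather than by isinstance checks on AttentionMetadata subclasses, so the model no longer imports FlashAttentionMetadata or XFormersMetadata. A simplified dispatch sketch with assumed helper names (not the model's actual code):

from enum import Enum, auto


class _Backend(Enum):  # stand-in for vllm.attention.selector._Backend
    FLASH_ATTN = auto()
    FLASH_ATTN_VLLM_V1 = auto()
    XFORMERS = auto()
    TORCH_SDPA = auto()


def select_kv_cache_write_path(backend: _Backend) -> str:
    # Mirrors the branching above: flash-attention style backends use the
    # reshape_and_cache_flash op, xformers/SDPA use the paged-attention cache.
    if backend in (_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1):
        return "reshape_and_cache_flash"
    elif backend in (_Backend.XFORMERS, _Backend.TORCH_SDPA):
        return "PagedAttention.write_to_paged_cache"
    raise ValueError(f"Unsupported Attention backend {backend} enum found.")


assert select_kv_cache_write_path(_Backend.TORCH_SDPA) == (
    "PagedAttention.write_to_paged_cache")
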
8 changes: 6 additions & 2 deletions vllm/platforms/openvino.py
@@ -1,7 +1,5 @@
 from typing import TYPE_CHECKING

-import openvino as ov
-import openvino.properties.hint as hints
 import torch

 import vllm.envs as envs
@@ -16,6 +14,12 @@

 logger = init_logger(__name__)

+try:
+    import openvino as ov
+    import openvino.properties.hint as hints
+except ImportError as e:
+    logger.warning("Failed to import OpenVINO with %r", e)
+

 class OpenVinoPlatform(Platform):
     _enum = PlatformEnum.OPENVINO
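Moving the OpenVINO imports under a try/except keeps vllm/platforms/openvino.py importable on machines without the openvino package; the failure is logged instead of raised at import time. A generic sketch of the same guarded-import pattern with an availability flag (the flag name is illustrative and not part of this commit):

import logging

logger = logging.getLogger(__name__)

try:
    import openvino as ov  # noqa: F401
    _OPENVINO_AVAILABLE = True
except ImportError as e:
    logger.warning("Failed to import OpenVINO with %r", e)
    _OPENVINO_AVAILABLE = False
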
2 changes: 1 addition & 1 deletion vllm/v1/attention/backends/flash_attn.py
@@ -19,7 +19,7 @@ def get_supported_head_sizes() -> List[int]:

     @staticmethod
     def get_name() -> str:
-        return "flash-attn-vllm-v1"
+        return "FLASH_ATTN_VLLM_V1"

     @staticmethod
     def get_impl_cls() -> Type["FlashAttentionImpl"]:
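The renamed value matches the FLASH_ATTN_VLLM_V1 member referenced in mllama.py above; lookup by enum member name is exact, so the old lower-case, hyphenated string would not resolve. A tiny check illustrating this, using a stand-in enum:

from enum import Enum, auto


class _Backend(Enum):  # stand-in for vllm.attention.selector._Backend
    FLASH_ATTN_VLLM_V1 = auto()


assert "FLASH_ATTN_VLLM_V1" in _Backend.__members__
assert "flash-attn-vllm-v1" not in _Backend.__members__
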
