From d099d22785012bc7297b46f69995f818bdfb4963 Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Sun, 24 Nov 2024 02:12:20 +0800
Subject: [PATCH] [Bugfix] Avoid import AttentionMetadata explicitly in Mllama
 (#10593)

Signed-off-by: Isotr0py <2037008807@qq.com>
---
 vllm/attention/backends/blocksparse_attn.py |  5 +++++
 vllm/attention/layer.py                     |  3 ++-
 vllm/model_executor/models/mllama.py        | 14 +++++++-------
 vllm/platforms/openvino.py                  |  8 ++++++--
 vllm/v1/attention/backends/flash_attn.py    |  2 +-
 5 files changed, 21 insertions(+), 11 deletions(-)

diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py
index 94002e36db2bb..9e54c3b40c54e 100644
--- a/vllm/attention/backends/blocksparse_attn.py
+++ b/vllm/attention/backends/blocksparse_attn.py
@@ -87,6 +87,11 @@ def __post_init__(self):
 
 class BlocksparseFlashAttentionBackend(AttentionBackend):
 
+    @staticmethod
+    def get_name() -> str:
+        # For attention layer compatibility
+        return "FLASH_ATTN"
+
     @staticmethod
     def get_impl_cls() -> Type["BlocksparseFlashAttentionImpl"]:
         return BlocksparseFlashAttentionImpl
diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index cb4dedf481c77..1bb335909484b 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -6,7 +6,7 @@
 
 import vllm.envs as envs
 from vllm.attention import AttentionMetadata, AttentionType
-from vllm.attention.selector import get_attn_backend
+from vllm.attention.selector import backend_name_to_enum, get_attn_backend
 from vllm.config import CacheConfig
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.model_executor.layers.quantization.base_config import (
@@ -98,6 +98,7 @@ def __init__(
         self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
                              alibi_slopes, sliding_window, kv_cache_dtype,
                              blocksparse_params, logits_soft_cap)
+        self.backend = backend_name_to_enum(attn_backend.get_name())
 
         # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
         # torch.compile works by registering the attention as one giant
diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py
index 41f62b37f3bd9..9e6634a9a7579 100644
--- a/vllm/model_executor/models/mllama.py
+++ b/vllm/model_executor/models/mllama.py
@@ -32,9 +32,8 @@
 
 import vllm.distributed.parallel_state as ps
 from vllm.attention import Attention, AttentionMetadata, AttentionType
-from vllm.attention.backends.flash_attn import FlashAttentionMetadata
-from vllm.attention.backends.xformers import XFormersMetadata
 from vllm.attention.ops.paged_attn import PagedAttention
+from vllm.attention.selector import _Backend
 from vllm.config import VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.inputs import (INPUT_REGISTRY, DummyData, EncoderDecoderInputs,
@@ -828,7 +827,8 @@ def _attention_with_mask(
     ) -> torch.Tensor:
         # Skip writing kv-cache for the initial profiling run.
         if len(kv_cache.shape) > 1:
-            if isinstance(attn_metadata, FlashAttentionMetadata):
+            if self.attn.backend in (_Backend.FLASH_ATTN,
+                                     _Backend.FLASH_ATTN_VLLM_V1):
                 cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
                 cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
                 torch.ops._C_cache_ops.reshape_and_cache_flash(
@@ -842,7 +842,7 @@ def _attention_with_mask(
                     1.0,
                     1.0,
                 )
-            elif isinstance(attn_metadata, XFormersMetadata):
+            elif self.attn.backend in (_Backend.XFORMERS, _Backend.TORCH_SDPA):
                 key_cache, value_cache = PagedAttention.split_kv_cache(
                     kv_cache, self.num_local_key_value_heads, self.head_dim)
                 cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
@@ -852,9 +852,9 @@ def _attention_with_mask(
                     attn_metadata.cross_slot_mapping, "auto", 1.0, 1.0)
             else:
                 raise ValueError(
-                    f"Unsupported AttentionMetadata {type(attn_metadata)} "
-                    f"class found. Expected the AttentionMetadata to "
-                    f"be either XFormersMetadata or FlashAttentionMetadata.")
+                    f"Unsupported Attention backend {self.attn.backend} "
+                    "enum found. Expected the Attention backend to be "
+                    "FLASH_ATTN, FLASH_ATTN_VLLM_V1, XFORMERS or TORCH_SDPA.")
 
             # We have to call torch.sdpa for prefill when using a
             # custom cross-attention mask. Because the mask is not a
diff --git a/vllm/platforms/openvino.py b/vllm/platforms/openvino.py
index 91e615481ff8e..ea5ec7b40b95c 100644
--- a/vllm/platforms/openvino.py
+++ b/vllm/platforms/openvino.py
@@ -1,7 +1,5 @@
 from typing import TYPE_CHECKING
 
-import openvino as ov
-import openvino.properties.hint as hints
 import torch
 
 import vllm.envs as envs
@@ -16,6 +14,12 @@
 
 logger = init_logger(__name__)
 
+try:
+    import openvino as ov
+    import openvino.properties.hint as hints
+except ImportError as e:
+    logger.warning("Failed to import OpenVINO with %r", e)
+
 
 class OpenVinoPlatform(Platform):
     _enum = PlatformEnum.OPENVINO
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index d98bb5a716e97..5f8535eaa303f 100644
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -19,7 +19,7 @@ def get_supported_head_sizes() -> List[int]:
 
     @staticmethod
    def get_name() -> str:
-        return "flash-attn-vllm-v1"
+        return "FLASH_ATTN_VLLM_V1"
 
     @staticmethod
     def get_impl_cls() -> Type["FlashAttentionImpl"]:
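
For illustration, here is a minimal, self-contained sketch of the dispatch pattern this patch moves Mllama to: the attention layer resolves its backend into an enum once at construction time, and the model then branches on that enum instead of isinstance-checking a backend-specific AttentionMetadata class. The `Backend` enum, `AttentionLayer` class, and `kv_cache_write_path` helper below are simplified stand-ins, not vLLM's actual `_Backend`, `Attention`, or KV-cache APIs.

```python
import enum


class Backend(enum.Enum):
    # Mirrors the backend names used in the patch.
    FLASH_ATTN = enum.auto()
    FLASH_ATTN_VLLM_V1 = enum.auto()
    XFORMERS = enum.auto()
    TORCH_SDPA = enum.auto()


class AttentionLayer:
    def __init__(self, backend_name: str) -> None:
        # Resolved once at construction, analogous to the patch's
        # `self.backend = backend_name_to_enum(attn_backend.get_name())`.
        self.backend = Backend[backend_name]

    def kv_cache_write_path(self) -> str:
        # Branch on the enum rather than on the concrete metadata type.
        if self.backend in (Backend.FLASH_ATTN, Backend.FLASH_ATTN_VLLM_V1):
            return "reshape_and_cache_flash"
        elif self.backend in (Backend.XFORMERS, Backend.TORCH_SDPA):
            return "PagedAttention.write_to_paged_cache"
        raise ValueError(f"Unsupported Attention backend {self.backend}")


if __name__ == "__main__":
    print(AttentionLayer("FLASH_ATTN_VLLM_V1").kv_cache_write_path())
```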