diff --git a/.gitignore b/.gitignore
index c3c8240..03935e6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -166,3 +166,13 @@ cython_debug/
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 .idea/
+
+# VS Code
+.vscode/
+
+# Local evaluation scripts and outputs
+evaluation/modified_evaluate.py
+evaluation/A100_evaluate.sh
+evaluation/4090_evaluate_ada.sh
+evaluation/logs/*
+evaluation/results/*
diff --git a/kvpress/__init__.py b/kvpress/__init__.py
index 2f1e040..2bd083c 100644
--- a/kvpress/__init__.py
+++ b/kvpress/__init__.py
@@ -12,6 +12,7 @@
 from kvpress.presses.snapkv_press import SnapKVPress
 from kvpress.presses.streaming_llm_press import StreamingLLMPress
 from kvpress.presses.tova_press import TOVAPress
+from kvpress.presses.ada_snapkv_press import AdaSnapKVPress
 
 __all__ = [
     "BasePress",
@@ -24,4 +25,5 @@
     "TOVAPress",
     "KVPressTextGenerationPipeline",
     "apply_per_layer_compression",
+    "AdaSnapKVPress",
 ]
diff --git a/kvpress/ada_attn.py b/kvpress/ada_attn.py
new file mode 100644
index 0000000..7fffb5b
--- /dev/null
+++ b/kvpress/ada_attn.py
@@ -0,0 +1,247 @@
+# Copyright (c) 2024 YuanFeng
+#
+# This file is part of the YuanFeng project and is licensed under the MIT License.
+# SPDX-License-Identifier: MIT
+
+import logging
+from dataclasses import dataclass
+from typing import Optional, Tuple
+
+import torch
+from transformers import Cache
+from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
+from transformers.utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10
+
+from kvpress.ada_cache import DynamicCacheSplitHeadFlatten
+
+logger = logging.getLogger(__name__)
+
+if is_flash_attn_2_available():
+    from flash_attn import flash_attn_varlen_func
+
+
+@dataclass
+class MetaData:
+    decoding_cu_seqlens_q: torch.Tensor
+    cu_seqlens_k: torch.Tensor
+    max_seqlen_k: int
+    cu_offset: torch.Tensor
+    cu_head_offset: torch.Tensor
+    head_lens: torch.Tensor
+    bsz: int
+    num_key_value_heads: int
+    seen_tokens: int
+
+
+class AdaLlamaFlashAttention(LlamaAttention):
+    """
+    Llama flash attention module for AdaKV. This module inherits from `LlamaAttention`, as the weights of the module stay untouched.
+    It uses flash_attn_varlen_func from the flash_attn library to perform attention over the flattened KV cache layout.
+ """ + + # update the metadata for the flatten cache during the decoding phase + def _update_metadata_while_compressing(self, head_lens, cu_seqlens_k,max_seqlen_k): + self.metadata.head_lens = head_lens + self.metadata.cu_seqlens_k = cu_seqlens_k + self.metadata.max_seqlen_k = max_seqlen_k + + + def _update_metadata(self, key_states): + bs, head, seqlen, dim = key_states.shape + + self.metadata.max_seqlen_k += seqlen + self.metadata.cu_seqlens_k += self.metadata.cu_offset * seqlen + self.metadata.head_lens += seqlen + self.metadata.seen_tokens += seqlen + + # init the metadata for the flattened cache during the prefilling phase + def _init_metadata(self, key_states): + + """ + this method is used to initialize metadata for the flatten cache, + input key_states is a regular key states with shape [bsz, num_key_value_heads, seqlen, head_dim] + """ + + bsz, num_key_value_heads, k_len, head_dim = key_states.shape + k_lens = bsz * num_key_value_heads * [k_len] + _device = key_states.device + max_seqlen_k = max(k_lens) + + head_seqlens_k = torch.tensor(k_lens, dtype=torch.int32, device=_device) + cu_seqlens = torch.cumsum(head_seqlens_k, dim=0, dtype=torch.int32) + cu_seqlens_k = torch.cat( + [torch.tensor([0], dtype=torch.int32, device=_device), cu_seqlens], dim=0) + + + decoding_q_lens = bsz * num_key_value_heads * [1] + decoding_head_seqlens_q = torch.tensor(decoding_q_lens, dtype=torch.int32,device=_device) + decoding_cu_seqlens_q = torch.cumsum(decoding_head_seqlens_q, dim=0, dtype=torch.int32) + decoding_cu_seqlens_q = torch.cat( + [ torch.tensor([0], dtype=torch.int32, device=_device), decoding_cu_seqlens_q], dim=0) + + + cu_offset = torch.arange(0, bsz * num_key_value_heads + 1, dtype=torch.int32, device=_device) + cu_head_offset = torch.arange(1, bsz * num_key_value_heads + 1, dtype=torch.int32, device=_device) + + self.metadata = MetaData( + decoding_cu_seqlens_q = decoding_cu_seqlens_q, + cu_seqlens_k = cu_seqlens_k, + max_seqlen_k = max_seqlen_k, + cu_offset = cu_offset, + cu_head_offset = cu_head_offset, + head_lens = head_seqlens_k, + bsz = bsz, + num_key_value_heads = num_key_value_heads, + seen_tokens= k_len + ) + + + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + # used to store the metadata for the flatten cache + self.metadata = None + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if not isinstance(past_key_value, DynamicCacheSplitHeadFlatten): + raise ValueError( + "current implementation of `AdaKV` only supports `DynamicCacheSplitHeadFlatten` as the cache type." + ) + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position, "attn": self} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + # query_states = query_states.transpose(1, 2) + # key_states = key_states.transpose(1, 2) + # value_states = value_states.transpose(1, 2) + + # dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. 
(LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + + + query_states = query_states.view(bsz,-1, self.num_key_value_groups,q_len ,self.head_dim) + + query_states = query_states.transpose(2, 3) + query_states = query_states.reshape(-1,self.num_key_value_groups,self.head_dim) + + + key_states = key_states.view(-1,1,self.head_dim) + value_states = value_states.view(-1,1,self.head_dim) + + + if q_len == 1: + # init metadata for flatten query states during prefilling phase + cu_seqlens_q = self.metadata.decoding_cu_seqlens_q + max_seqlen_q = 1 + else: + # init metadata for flatten query states during prefilling phase + prefill_q_lens = bsz * self.num_heads//self.num_key_value_groups * [q_len] + head_seqlens_q = torch.tensor(prefill_q_lens, dtype=torch.int32, device=query_states.device) + cu_seqlens_q = torch.cumsum(head_seqlens_q, dim=0, dtype=torch.int32) + cu_seqlens_q = torch.cat( + [torch.tensor([0], dtype=torch.int32, device=query_states.device), cu_seqlens_q], dim=0) + max_seqlen_q = q_len + + cu_seqlens_k = self.metadata.cu_seqlens_k + max_seqlen_k = self.metadata.max_seqlen_k + + + attn_output = flash_attn_varlen_func(query_states, key_states, value_states, cu_seqlens_q, + cu_seqlens_k, max_seqlen_q, max_seqlen_k, causal=True) + # TODO: support batch size > 1 + assert bsz == 1 + + attn_output = attn_output.reshape(bsz, self.num_key_value_heads, q_len, self.num_key_value_groups, self.head_dim) + attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + diff --git a/kvpress/ada_cache.py b/kvpress/ada_cache.py new file mode 100644 index 0000000..6e96911 --- /dev/null +++ b/kvpress/ada_cache.py @@ -0,0 +1,108 @@ +# Copyright (c) 2024 YuanFeng +# +# This file is part of the YuanFeng project and is licensed under the MIT License. 
+# SPDX-License-Identifier: MIT
+
+from typing import List, Optional, Tuple
+
+import torch
+from transformers.cache_utils import Cache
+
+
+class DynamicCacheSplitHeadFlatten(Cache):
+    """
+    Flattened KV cache layout with a customized CUDA update kernel.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.key_cache: List[List[torch.Tensor]] = []
+        self.value_cache: List[List[torch.Tensor]] = []
+        self._seen_tokens = 0
+
+    def __len__(self):
+        return len(self.key_cache)
+
+    def __iter__(self):
+        for layer_idx in range(len(self)):
+            yield (tuple(self.key_cache[layer_idx]), tuple(self.value_cache[layer_idx]))
+
+    def __getitem__(self, layer_idx: int) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
+        if layer_idx < len(self):
+            return (tuple(self.key_cache[layer_idx]), tuple(self.value_cache[layer_idx]))
+        else:
+            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
+
+    def update(self, key_states, value_states, layer_idx, cache_kwargs=None):
+        # NOTE: key/value states arrive as [bsz, num_key_value_heads, seqlen, head_dim];
+        # each layer is stored in a flattened layout of shape
+        # [bsz * (head_0_len + head_1_len + ... + head_n_len), head_dim]
+        attn = cache_kwargs.get("attn", None)
+        if len(self.key_cache) <= layer_idx:
+            # prefilling: flatten the key and value states
+            bs, head_num, seqlen, head_dim = key_states.shape
+            flatten_key_cache = key_states.reshape(bs * head_num * seqlen, head_dim)
+            flatten_value_cache = value_states.reshape(bs * head_num * seqlen, head_dim)
+            self.key_cache.append(flatten_key_cache)
+            self.value_cache.append(flatten_value_cache)
+
+            # init metadata for the flattened key states
+            attn._init_metadata(key_states)
+            self._seen_tokens = attn.metadata.seen_tokens
+        else:
+            # decoding
+            assert self.key_cache[layer_idx].dim() == 2
+            bs, head, seqlen, head_dim = key_states.shape
+
+            # TODO: currently only supports bs == 1
+            assert bs == 1, f"bs: {bs}"
+            # NOTE: phase 2 (decoding): append the incoming [bs, head, seqlen, dim] key/value states
+            head_lens = attn.metadata.head_lens
+            cu_seqlens_k = attn.metadata.cu_seqlens_k
+
+            # TODO: wrap as a python interface
+            from tiny_api_cuda import update_flatten_klenN_view
+            new_key_cache = update_flatten_klenN_view(self.key_cache[layer_idx].view(-1, head_dim), key_states, head_lens, cu_seqlens_k)
+            new_value_cache = update_flatten_klenN_view(self.value_cache[layer_idx].view(-1, head_dim), value_states, head_lens, cu_seqlens_k)
+
+            self.key_cache[layer_idx] = new_key_cache
+            self.value_cache[layer_idx] = new_value_cache
+
+            # update metadata
+            attn._update_metadata(key_states)
+            self._seen_tokens = attn.metadata.seen_tokens
+
+        return self.key_cache[layer_idx], self.value_cache[layer_idx]
+
+    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+        if len(self.key_cache) <= layer_idx:
+            return 0
+
+        # TODO: return 1 for now, simply to signal that the cache has content
+        return 1
+
+    def get_max_length(self) -> Optional[int]:
+        return None
+
+    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
+        """Converts the `DynamicCacheSplitHeadFlatten` instance into its equivalent in the legacy cache format."""
+        legacy_cache = ()
+        for layer_idx in range(len(self)):
+            legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
+        return legacy_cache
+
+    @classmethod
+    def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCacheSplitHeadFlatten":
+        """Converts a cache in the legacy cache format into an equivalent `DynamicCacheSplitHeadFlatten`."""
+        cache = cls()
+        if past_key_values is not None:
+            for layer_idx in range(len(past_key_values)):
+                key_states, value_states = past_key_values[layer_idx]
+                cache.key_cache.append(key_states)
+                cache.value_cache.append(value_states)
+
+        # TODO: seen tokens should be updated here
+        cache._seen_tokens = None
+        return cache
\ No newline at end of file
diff --git a/kvpress/csrc.zip b/kvpress/csrc.zip
new file mode 100644
index 0000000..f4572a5
Binary files /dev/null and b/kvpress/csrc.zip differ
diff --git a/kvpress/csrc/LICENSE b/kvpress/csrc/LICENSE
new file mode 100644
index 0000000..5c65dde
--- /dev/null
+++ b/kvpress/csrc/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2023 66RING
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
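Reviewer note: the sketch below is not part of the patch. It is a minimal, self-contained illustration of how the flattened layout used by DynamicCacheSplitHeadFlatten lines up with the cu_seqlens_k / head_lens metadata built in AdaLlamaFlashAttention._init_metadata; tensor sizes are made up for the example.

import torch

bsz, num_kv_heads, seqlen, head_dim = 1, 4, 6, 8
keys = torch.randn(bsz, num_kv_heads, seqlen, head_dim)

# Prefill: every head holds `seqlen` entries, stored back-to-back in one 2D tensor.
flat_keys = keys.reshape(bsz * num_kv_heads * seqlen, head_dim)

head_lens = torch.full((bsz * num_kv_heads,), seqlen, dtype=torch.int32)
cu_seqlens_k = torch.cat([torch.zeros(1, dtype=torch.int32),
                          torch.cumsum(head_lens, dim=0, dtype=torch.int32)])

# Rows belonging to head h are flat_keys[cu_seqlens_k[h] : cu_seqlens_k[h + 1]].
h = 2
head_rows = flat_keys[cu_seqlens_k[h]: cu_seqlens_k[h + 1]]
assert torch.equal(head_rows, keys[0, h])

# After AdaKV compression the per-head lengths differ, but the bookkeeping is the same:
# cu_seqlens_k is simply the prefix sum of the (now unequal) head_lens.
compressed_lens = torch.tensor([3, 5, 2, 4], dtype=torch.int32)
compressed_cu = torch.cat([torch.zeros(1, dtype=torch.int32),
                           torch.cumsum(compressed_lens, dim=0, dtype=torch.int32)])
print(compressed_cu)  # tensor([ 0,  3,  8, 10, 14], dtype=torch.int32)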
diff --git a/kvpress/csrc/build.py b/kvpress/csrc/build.py new file mode 100644 index 0000000..8f8c086 --- /dev/null +++ b/kvpress/csrc/build.py @@ -0,0 +1,109 @@ +import subprocess +import os +from packaging.version import parse, Version +from pathlib import Path +from setuptools import setup, find_packages +from torch.utils.cpp_extension import ( + BuildExtension, + CppExtension, + CUDAExtension, + CUDA_HOME, +) + +# package name managed by pip, which can be remove by `pip uninstall tiny_pkg` +PACKAGE_NAME = "tiny_pkg" + +ext_modules = [] +generator_flag = [] +cc_flag = [] +cc_flag.append("-gencode") +cc_flag.append("arch=compute_80,code=sm_80") + + +# helper function to get cuda version +def get_cuda_bare_metal_version(cuda_dir): + raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) + output = raw_output.split() + release_idx = output.index("release") + 1 + bare_metal_version = parse(output[release_idx].split(",")[0]) + + return raw_output, bare_metal_version + + +# if CUDA_HOME is not None: +# _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) +# if bare_metal_version >= Version("11.8"): +# cc_flag.append("-gencode") +# cc_flag.append("arch=compute_90,code=sm_90") + +# ninja build does not work unless include_dirs are abs path +this_dir = os.path.dirname(os.path.abspath(__file__)) + +# cuda module +ext_modules.append( + CUDAExtension( + # package name for import + name="tiny_api_cuda", + sources=[ + "csrc/cuda_api.cu", + ], + extra_compile_args={ + # add c compile flags + "cxx": ["-O3", "-std=c++17"] + generator_flag, + # add nvcc compile flags + "nvcc": [ + "-O3", + "-std=c++17", + "-U__CUDA_NO_HALF_OPERATORS__", + "--use_fast_math", + "-lineinfo", + "--ptxas-options=-v", + "--ptxas-options=-O2", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "--use_fast_math", + ] + + generator_flag + + cc_flag, + }, + include_dirs=[ + Path(this_dir) / "csrc", + Path(this_dir) / "include", + # Path(this_dir) / "some" / "thing" / "more", + ], + ) +) + +setup( + name=PACKAGE_NAME, + packages=find_packages( + exclude=( + "build", + "csrc", + "include", + "tests", + "dist", + "docs", + "benchmarks", + "tiny_pkg.egg-info", + ) + ), + description="Tiny cuda and c api binding for pytorch.", + ext_modules=ext_modules, + cmdclass={ "build_ext": BuildExtension}, + python_requires=">=3.7", + install_requires=[ + "torch", + "einops", + "packaging", + "ninja", + ], +) + + + + diff --git a/kvpress/csrc/csrc/cuda_api.cu b/kvpress/csrc/csrc/cuda_api.cu new file mode 100644 index 0000000..b2b17d4 --- /dev/null +++ b/kvpress/csrc/csrc/cuda_api.cu @@ -0,0 +1,217 @@ +#include +#include +#include +#include +#include +#include + +#include "cuda_api.h" +#include "static_switch.h" + +template +__global__ void update_flatten_view_klenN_kernel(tensor_t *dst_ptr, tensor_t *src_ptr, + tensor_t *state_ptr, int *headlens, + int *cu_headlens, int dim, + int new_klen) { + // NOTE(66ring): + // cache.shape = (total_len * total_head, head_dim) + // states.shape = (bz, head_num, k_len, dim) + + int head_idx = blockIdx.x; + int thread_group = blockIdx.y; + int tid = threadIdx.x + thread_group * blockDim.x; + int num_threads = blockDim.x * gridDim.y; + + int headlen = headlens[head_idx]; + + // get position of src, dst, insert ptr + int src_cum_off = cu_headlens[head_idx] * dim; + int dst_cum_off = src_cum_off + head_idx * 
new_klen * dim; + + + auto old_cache_ptr = src_ptr + src_cum_off; + auto new_cache_ptr = dst_ptr + dst_cum_off; + + // copy old data + for (int start_addr = 0; start_addr < headlen * dim; start_addr += kblock_size * num_threads) { + auto src_addr = old_cache_ptr + start_addr + tid * kblock_size; + auto dst_addr = new_cache_ptr + start_addr + tid * kblock_size; + + // TODO: LDSM speed up with SRAM + #pragma unroll + for (int i = 0; i < kblock_size; i++) { + if (start_addr + tid * kblock_size + i >= headlen * dim) { + break; + } + dst_addr[i] = src_addr[i]; + } + } + + // insert new data + int insert_off = (cu_headlens[head_idx + 1] + head_idx * new_klen) * dim; + auto insert_cache_ptr = dst_ptr + insert_off; + for (int start_addr = 0; start_addr < new_klen * dim; start_addr += kblock_size * num_threads) { + auto src_addr = (state_ptr + head_idx * new_klen * dim) + start_addr + tid * kblock_size; + auto dst_addr = insert_cache_ptr + start_addr + tid * kblock_size; + + // TODO: LDSM speed up with SRAM + #pragma unroll + for (int i = 0; i < kblock_size; i++) { + if (start_addr + tid * kblock_size + i >= new_klen * dim) { + break; + } + dst_addr[i] = src_addr[i]; + } + } + +} + +template +__global__ void update_flatten_view_kernel(tensor_t *dst_ptr, tensor_t *src_ptr, + tensor_t *state_ptr, int *headlens, + int *cu_headlens, int dim) { + // Create new tensor from cache and insert element into it. + + int head_idx = blockIdx.x; + int thread_group = blockIdx.y; + int tid = threadIdx.x + thread_group * blockDim.x; + int num_threads = blockDim.x * gridDim.y; + + int headlen = headlens[head_idx]; + + // get position of src, dst, insert ptr + int src_cum_off = cu_headlens[head_idx] * dim; + int dst_cum_off = src_cum_off + head_idx * dim; + + auto old_cache_ptr = src_ptr + src_cum_off; + auto new_cache_ptr = dst_ptr + dst_cum_off; + + // copy old data + for (int start_addr = 0; start_addr < headlen * dim; + start_addr += kblock_size * num_threads) { + auto src_addr = old_cache_ptr + start_addr + tid * kblock_size; + auto dst_addr = new_cache_ptr + start_addr + tid * kblock_size; + +// TODO: LDSM speed up with SRAM +#pragma unroll + for (int i = 0; i < kblock_size; i++) { + if (start_addr + tid * kblock_size + i >= headlen * dim) { + break; + } + dst_addr[i] = src_addr[i]; + } + } + + // insert new data + int new_klen = 1; + int insert_off = (cu_headlens[head_idx + 1] + head_idx * new_klen) * dim; + auto insert_cache_ptr = dst_ptr + insert_off; + for (int start_addr = 0; start_addr < new_klen * dim; start_addr += kblock_size * num_threads) { + auto src_addr = (state_ptr + head_idx * new_klen * dim) + start_addr + tid * kblock_size; + auto dst_addr = insert_cache_ptr + start_addr + tid * kblock_size; + + // TODO: LDSM speed up with SRAM + #pragma unroll + for (int i = 0; i < kblock_size; i++) { + if (start_addr + tid * kblock_size + i >= new_klen * dim) { + break; + } + dst_addr[i] = src_addr[i]; + } + } +} + +torch::Tensor update_flatten_view(torch::Tensor &cache, torch::Tensor &state, + torch::Tensor &headlens, + torch::Tensor &cu_headlens) { + TORCH_CHECK(headlens.dtype() == torch::kInt32, "expected headlens to be int32"); + TORCH_CHECK(cu_headlens.dtype() == torch::kInt32, "expected cu_dst_pos to be int32"); + + auto cache_shape = cache.sizes(); + + int origin_len = cache_shape[0]; + int head_dim = cache_shape[1]; + int head_num = headlens.sizes()[0]; + + torch::Tensor out = torch::empty({origin_len + head_num, head_dim}, cache.options()); + + const int kblock_size = 1; + const int num_threads_group 
= 1024; + const int num_threads = 256; + + dim3 grid(head_num, num_threads_group); + + // TODO: dispatch with head_dim?? may loss performance + dim3 block(num_threads); + + FP16_SWITCH(cache.dtype() == torch::kFloat16, [&] { + auto kernel = update_flatten_view_kernel; + kernel<<>>( + (elem_type *)out.data_ptr(), (elem_type *)cache.data_ptr(), + (elem_type *)state.data_ptr(), (int *)headlens.data_ptr(), + (int *)cu_headlens.data_ptr(), head_dim); + }); + + // TODO: when to use sync or torch auto + // cudaDeviceSynchronize(); + + return out; +} + +torch::Tensor update_flatten_klenN_view(torch::Tensor &cache, + torch::Tensor &state, + torch::Tensor &headlens, + torch::Tensor &cu_headlens) { + // NOTE(66ring): + // cache.shape = (total_len * total_head, head_dim) + // states.shape = (bz, head_num, k_len, dim) + + TORCH_CHECK(headlens.dtype() == torch::kInt32, + "expected headlens to be int32"); + TORCH_CHECK(cu_headlens.dtype() == torch::kInt32, + "expected cu_dst_pos to be int32"); + + + cache = cache.contiguous(); + state = state.contiguous(); + auto cache_shape = cache.sizes(); + auto state_shape = state.sizes(); + + int origin_len = cache_shape[0]; + int new_klen = state_shape[2]; + int new_flatlen = state_shape[0] * state_shape[1] * state_shape[2]; + int head_dim = cache_shape[1]; + int head_num = headlens.sizes()[0]; + + torch::Tensor out = + torch::empty({origin_len + new_flatlen, head_dim}, cache.options()); + + const int kblock_size = 1; + const int num_threads_group = 1024; + const int num_threads = 256; + + dim3 grid(head_num, num_threads_group); + + // TODO: dispatch with head_dim?? may loss performance + dim3 block(num_threads); + + FP16_SWITCH(cache.dtype() == torch::kFloat16, [&] { + auto kernel = update_flatten_view_klenN_kernel; + kernel<<>>( + (elem_type *)out.data_ptr(), (elem_type *)cache.data_ptr(), + (elem_type *)state.data_ptr(), (int *)headlens.data_ptr(), + (int *)cu_headlens.data_ptr(), head_dim, new_klen); + }); + + // TODO: when to use sync or torch auto + // cudaDeviceSynchronize(); + + return out; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + // m.def("package_name", &function_name, "function_docstring"") + + m.def("update_flatten_view", &update_flatten_view, "update flatten view cache"); + m.def("update_flatten_klenN_view", &update_flatten_klenN_view, "update flatten view cache"); +} diff --git a/kvpress/csrc/csrc/static_switch.h b/kvpress/csrc/csrc/static_switch.h new file mode 100644 index 0000000..6b454d7 --- /dev/null +++ b/kvpress/csrc/csrc/static_switch.h @@ -0,0 +1,14 @@ +#pragma once + +#define FP16_SWITCH(COND, ...) \ + [&] { \ + if (COND) { \ + using elem_type = at::Half; \ + return __VA_ARGS__(); \ + } else { \ + using elem_type = at::BFloat16; \ + return __VA_ARGS__(); \ + } \ + }() + + diff --git a/kvpress/csrc/include/cuda_api.h b/kvpress/csrc/include/cuda_api.h new file mode 100644 index 0000000..d02d4dc --- /dev/null +++ b/kvpress/csrc/include/cuda_api.h @@ -0,0 +1,34 @@ +#pragma once + +#include + +#define DEBUG 1 + +#ifdef DEBUG + +// NOTE:tensor malloc as device before we call +// e.g. 
data.to("cuda") in python +#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) +#define CUDA_ERROR_CHECK(condition) \ + do { \ + cudaError_t error = condition; \ + if (error != cudaSuccess) { \ + printf("CUDA_CHECK error in line %d of file %s \ + : %s \n", \ + __LINE__, __FILE__, cudaGetErrorString(error)); \ + exit(EXIT_FAILURE); \ + } \ + } while (0) + +#else + +#define CHECK_CUDA(x) do { } while (0) +#define CHECK_CONTIGUOUS(x) do { } while (0) +#define CHECK_INPUT(x) do { } while (0) +#define CUDA_ERROR_CHECK(condition) do { condition; } while (0) + +#endif // DEBUG + + diff --git a/kvpress/csrc/makefile b/kvpress/csrc/makefile new file mode 100644 index 0000000..4e1f0fa --- /dev/null +++ b/kvpress/csrc/makefile @@ -0,0 +1,4 @@ +build: + python build.py install + +.PHONY: build diff --git a/kvpress/csrc/test.py b/kvpress/csrc/test.py new file mode 100644 index 0000000..4d56867 --- /dev/null +++ b/kvpress/csrc/test.py @@ -0,0 +1,92 @@ +import torch +import random +from tiny_api_cuda import update_flatten_view, update_flatten_klenN_view + + +def test_single_insert(head_num, head_dim): + head_lens = [] + for _ in range(head_num): + head_lens.append(random.randint(1, 100)) + head_lens = torch.tensor(head_lens, dtype=torch.int32, device='cuda') + klen_sum = torch.sum(head_lens, dtype=torch.int32) + cu_klen = torch.cumsum(head_lens, 0, dtype=torch.int32) - head_lens + cu_klen = torch.cat([cu_klen, torch.tensor([klen_sum], dtype=torch.int32, device="cuda")], dim=0) + + head_cache_list = [] + for hl in head_lens: + head_cache = torch.randn((hl, head_dim), dtype=torch.bfloat16, device="cuda") + head_cache_list.append(head_cache) + head_cache_tensor = torch.cat(head_cache_list, dim=0) + key_state0 = torch.randn((1, head_num, 1, head_dim), dtype=torch.bfloat16, device="cuda") + + expected_cache = head_cache_list.copy() + for i in range(head_num): + expected_cache[i] = torch.cat([expected_cache[i], key_state0[0, i, 0, :].view(1, head_dim)], dim=0) + expected_cache = torch.cat(expected_cache, dim=0) + ref_new_state_0 = update_flatten_view(head_cache_tensor.view(-1, head_dim), key_state0.view(-1, head_dim), head_lens, cu_klen) + + assert torch.equal(expected_cache, ref_new_state_0) + +def test_single_insertN(head_num, head_dim, klen): + head_lens = [] + seqlen = 3810 + # for _ in range(head_num): + # head_lens.append(random.randint(1, 100)) + head_lens = [seqlen] * head_num + head_lens = torch.tensor(head_lens, dtype=torch.int32, device='cuda') + klen_sum = torch.sum(head_lens, dtype=torch.int32) + cu_klen = torch.cumsum(head_lens, 0, dtype=torch.int32) - head_lens + cu_klen = torch.cat([cu_klen, torch.tensor([klen_sum], dtype=torch.int32, device="cuda")], dim=0) + # print("cu_klen", cu_klen) + + # cu_klen = torch.cumsum(head_lens, 0, dtype=torch.int32) + # cu_klen = torch.cat([torch.tensor([0], dtype=torch.int32, device="cuda"), cu_klen, ], dim=0) + # print("cu_klen", cu_klen) + # input("Press Enter to continue...") + # head_cache_list = [] + # for hl in head_lens: + # head_cache = torch.randn((hl, head_dim), dtype=torch.bfloat16, device="cuda") + # head_cache_list.append(head_cache) + # head_cache_tensor = torch.cat(head_cache_list, dim=0) + key_state0 = torch.randn((1, head_num, klen, head_dim), dtype=torch.bfloat16, device="cuda") + head_cache = torch.randn((1, head_num, seqlen, head_dim), dtype=torch.bfloat16, 
device="cuda")
+
+    expected_cache = torch.cat([head_cache, key_state0], dim=2)
+    print("expected_cache.shape", expected_cache.shape)
+    expected_cache = expected_cache.view(-1, head_dim)
+    print("expected_cache.shape", expected_cache.shape)
+
+    head_cache = head_cache.view(-1, head_dim)
+    # expected_cache = head_cache_list.copy()
+    # for i in range(head_num):
+    #     expected_cache[i] = torch.cat([expected_cache[i], key_state0[0, i, :, :].view(-1, head_dim)], dim=0)
+    # expected_cache = torch.cat(expected_cache, dim=0)
+
+    print("head_cache_tensor.shape", head_cache.shape)
+    print("key_state0.shape", key_state0.shape)
+    print("head_lens", head_lens)
+    print("cu_klen", cu_klen)
+    ref_new_state_0 = update_flatten_klenN_view(head_cache, key_state0, head_lens, cu_klen)
+    print("ref_new_state_0.shape", ref_new_state_0.shape)
+
+    assert torch.equal(expected_cache, ref_new_state_0)
+    print(f"({head_num}, {head_dim}, {klen}) test passed")
+
+
+def main(seed):
+    random.seed(seed)
+    torch.manual_seed(seed)
+
+    for head_num in [8]:
+        for head_dim in [128]:
+            # test_single_insert(head_num, head_dim)
+            for klen in [1, 2, 128, 256, 512]:
+                test_single_insertN(head_num, head_dim, klen)
+
+
+if __name__ == "__main__":
+    for seed in range(100):
+        main(seed)
diff --git a/kvpress/pipeline.py b/kvpress/pipeline.py
index 1ae45ae..46a6fe1 100644
--- a/kvpress/pipeline.py
+++ b/kvpress/pipeline.py
@@ -2,21 +2,25 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import contextlib
 import logging
 from typing import Optional
 
 import torch
 from transformers import AutoModelForCausalLM, Cache, DynamicCache, QuantizedCache, Pipeline
 from transformers.pipelines import PIPELINE_REGISTRY
 from transformers.pipelines.base import GenericTensor
 
-from kvpress.presses.base_press import BasePress
+from kvpress.ada_cache import DynamicCacheSplitHeadFlatten
+from kvpress.presses.base_press import BasePress, AdaBasePress
 from kvpress.presses.observed_attention_press import ObservedAttentionPress
 
 logger = logging.getLogger(__name__)
+
+
 class KVPressTextGenerationPipeline(Pipeline):
     """
     Pipeline for key-value compression in causal language models.
@@ -66,7 +70,6 @@ def _sanitize_parameters(
         - forward_kwargs: The keyword arguments for the forward function.
         - postprocess_kwargs: The keyword arguments for the postprocess function.
         """
-        answer_prefix = answer_prefix or ""
         postprocess_kwargs = {"single_question": questions is None}
 
         assert question is None or questions is None, "Either question or questions should be provided, not both."
@@ -113,7 +116,7 @@ def preprocess(
         # Add question_suffix and answer prefix
         # e.g. for llama3.1, question_suffix="<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
         questions = [question + question_suffix + answer_prefix for question in questions]
-
+
         # Tokenize the context and questions
         context_ids = self.tokenizer.encode(context, return_tensors="pt", add_special_tokens=False)
         question_ids = [
@@ -161,7 +164,11 @@ def _forward(
         # Prefilling using the press on the context
         if cache is None:
-            cache = DynamicCache()
+            # Use the flattened cache layout when the press follows the AdaKV paradigm
+            if isinstance(press, AdaBasePress):
+                cache = DynamicCacheSplitHeadFlatten()
+            else:
+                cache = DynamicCache()
 
         with press(self.model) if press is not None else contextlib.nullcontext():
             self.model(
@@ -247,14 +257,14 @@ def generate_answer(
                 break
 
         answer = self.tokenizer.decode(torch.stack(generated_ids), skip_special_tokens=True)
 
+        # NOTE: removing the generated tokens from the cache is disabled here: the flattened
+        # AdaKV layout stores a single 2D tensor per layer, which cannot be truncated by
+        # slicing along the sequence dimension.
-        # Remove the generated tokens from the cache
-        if isinstance(cache, QuantizedCache):
-            key_attr, value_attr = "_quantized_key_cache", "_quantized_value_cache"
-        else:
-            key_attr, value_attr = "key_cache", "value_cache"
+        # # Remove the generated tokens from the cache
+        # if isinstance(cache, QuantizedCache):
+        #     key_attr, value_attr = "_quantized_key_cache", "_quantized_value_cache"
+        # else:
+        #     key_attr, value_attr = "key_cache", "value_cache"
 
-        setattr(cache, key_attr, [key[:, :, :c] for key, c in zip(getattr(cache, key_attr), cache_seq_lengths)])
-        setattr(cache, value_attr, [value[:, :, :c] for value, c in zip(getattr(cache, value_attr), cache_seq_lengths)])
+        # setattr(cache, key_attr, [key[:, :, :c] for key, c in zip(getattr(cache, key_attr), cache_seq_lengths)])
+        # setattr(cache, value_attr, [value[:, :, :c] for value, c in zip(getattr(cache, value_attr), cache_seq_lengths)])
 
         return answer
diff --git a/kvpress/presses/ada_snapkv_press.py b/kvpress/presses/ada_snapkv_press.py
new file mode 100644
index 0000000..abf0aae
--- /dev/null
+++ b/kvpress/presses/ada_snapkv_press.py
@@ -0,0 +1,118 @@
+# Author: Yuan Feng
+# Corresponding Paper: Ada-KV
+
+import inspect
+import math
+from dataclasses import dataclass
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from transformers.models.llama.modeling_llama import repeat_kv, rotate_half
+
+from kvpress.presses.base_press import AdaBasePress
+
+
+@dataclass
+class AdaSnapKVPress(AdaBasePress):
+    """
+    AdaSnapKV combines SnapKV scoring with Ada-KV head-adaptive budget allocation.
+    SnapKV (https://arxiv.org/abs/2404.14469) uses the attention of the latest window_size tokens to estimate the
+    importance of the previous KV pairs; AdaBasePress then distributes the retained budget across heads.
+    We use the default scoring settings from:
+    https://github.com/FasterDecoding/SnapKV/blob/main/snapkv/monkeypatch/snapkv_utils.py#L24
+    """
+
+    compression_ratio: float = 0.0
+    window_size: int = 64
+    kernel_size: int = 5
+    floor_alpha: float = 0.2
+
+    def compute_window_attention(self, module, hidden_states, keys):
+        """
+        Compute the last window_size queries and associated attention weights for the first q_len - window_size keys.
+        """
+
+        bsz, q_len, _ = hidden_states.shape
+
+        # Get last window_size queries
+        if hasattr(module, "q_proj"):
+            query_states = module.q_proj(hidden_states[:, -self.window_size :])
+        elif hasattr(module, "qkv_proj"):
+            qkv = module.qkv_proj(hidden_states[:, -self.window_size :])
+            query_states = qkv[..., : module.num_heads * module.head_dim]
+        else:
+            raise NotImplementedError(f"SnapKV not yet implemented for {module.__class__}.")
+
+        query_states = query_states.view(bsz, self.window_size, module.num_heads, module.head_dim).transpose(1, 2)
+
+        # Apply RoPE
+        if "position_ids" in inspect.signature(module.rotary_emb.forward).parameters:
+            position_ids = torch.arange(q_len - self.window_size, q_len).unsqueeze(0).to(query_states.device)
+            cos, sin = module.rotary_emb(query_states, position_ids)
+        else:
+            cos, sin = module.rotary_emb(query_states, q_len)
+            cos, sin = cos[-self.window_size :].unsqueeze(0), sin[-self.window_size :].unsqueeze(0)
+        query_states = (query_states * cos) + (rotate_half(query_states) * sin)
+
+        # Compute attention for first q_len - window_size tokens
+        key_states = repeat_kv(keys, module.num_key_value_groups)
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(module.head_dim)
+        attention_mask = torch.ones_like(attn_weights) * float("-inf")
+        attention_mask = torch.triu(attention_mask, diagonal=q_len - self.window_size + 1)
+        attn_weights += attention_mask
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_weights = attn_weights[..., : -self.window_size]
+
+        return attn_weights
+
+    def score(
+        self,
+        module: nn.Module,
+        hidden_states: torch.Tensor,
+        keys: torch.Tensor,
+        values: torch.Tensor,
+        attentions: torch.Tensor,
+        kwargs,
+    ) -> torch.Tensor:
+        """
+        Score the KV pairs of each head; positions assigned the maximum score are guaranteed
+        to be kept when AdaBasePress selects the retained budget across heads.
+        """
+        cache_metadata = module.metadata
+
+        # The current implementation only supports compressing once: check that this is the first compression
+        head_lens = cache_metadata.head_lens
+        assert all(x == head_lens[0] for x in head_lens), "Not all elements in head_lens are the same, implying multiple compressions"
+
+        # Convert to (bsz, num_key_value_heads, q_len, head_dim) for scoring
+        keys = keys.view(cache_metadata.bsz, cache_metadata.num_key_value_heads, cache_metadata.head_lens[0], keys.shape[-1])
+        values = values.view(cache_metadata.bsz, cache_metadata.num_key_value_heads, cache_metadata.head_lens[0], keys.shape[-1])
+
+        bsz, num_key_value_heads, q_len, _ = keys.shape
+
+        assert q_len > self.window_size, "Query length should be greater than the window size"
+
+        if attentions is not None:
+            attn_weights = attentions[..., -self.window_size :, : -self.window_size]
+        else:
+            attn_weights = self.compute_window_attention(module, hidden_states, keys)
+
+        scores = attn_weights.mean(dim=-2)
+        scores = F.avg_pool1d(scores, kernel_size=self.kernel_size, padding=self.kernel_size // 2, stride=1)
+
+        # Average per group (https://github.com/FasterDecoding/SnapKV/issues/22)
+        scores = scores.view(bsz, num_key_value_heads, module.num_key_value_groups, q_len - self.window_size)
+        scores = scores.mean(2)
+
+        # AdaKV safeguard: guarantee each head keeps at least floor_alpha of its average budget
+        # by assigning the maximum score to its local top entries
+        compress_q_len = q_len * (1 - self.compression_ratio) * self.floor_alpha
+        topk_idx = scores.topk(int(compress_q_len), dim=-1).indices
+        scores.scatter_(-1, topk_idx, torch.finfo(scores.dtype).max)
+
+        # Add back the observation window. Use max score to make sure the window is not pruned.
+        scores = F.pad(scores, (0, self.window_size), value=scores.max().item())
+
+        return scores
diff --git a/kvpress/presses/base_press.py b/kvpress/presses/base_press.py
index 553187a..8cd5283 100644
--- a/kvpress/presses/base_press.py
+++ b/kvpress/presses/base_press.py
@@ -2,10 +2,14 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import logging
 from contextlib import contextmanager
 from typing import Generator
 
 import torch
 from torch import nn
 from transformers import (
@@ -17,8 +21,8 @@
     QuantizedCache,
 )
 
-logger = logging.getLogger(__name__)
+logger = logging.getLogger(__name__)
 
 class BasePress:
     """Base class for pruning methods.
@@ -150,3 +154,128 @@ def __call__(self, model: PreTrainedModel) -> Generator:
         finally:
             for forward_hook in hooks:
                 forward_hook.remove()
+
+
+class AdaBasePress(BasePress):
+    """Base class for pruning methods following the Ada-KV paradigm.
+    Each pruning method should implement a `score` method that computes the scores for each KV pair in a layer.
+    This score is used to prune the KV pairs with the lowest scores in the `forward_hook` method.
+    The `forward_hook` method is called after the forward pass of a layer and updates the cache with the pruned KV pairs.
+    The press can be applied to a model by calling it with the model as an argument.
+    """
+
+    def __init__(self, compression_ratio: float = 0.0):
+        self.compression_ratio = compression_ratio
+        assert 0 <= compression_ratio < 1, "Compression ratio must be between 0 and 1"
+
+    # Override BasePress.forward_hook to implement the AdaKV paradigm
+    def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list):
+        """Cache compression hook called after the forward pass of a decoder layer.
+        The hook is applied only during the pre-filling phase if there is some pruning ratio.
+        The current implementation removes a fixed total number of KV pairs per layer, allocated adaptively across heads.
+
+        Parameters
+        ----------
+        module :
+            Transformer attention layer.
+        input :
+            Input to the hook. This is the input to the forward pass of the layer.
+        kwargs :
+            Keyword arguments, as given to the forward pass of the layer.
+        output :
+            Output of the hook. This is the original output of the forward pass of the layer.
+
+        Returns
+        -------
+        Modified output of the forward pass of the layer.
+
+        """
+        # See e.g.
LlamaDecoderLayer.forward for the output structure + if len(output) == 3: + _, attentions, cache = output + else: + attentions, cache = None, output[-1] + + hidden_states = kwargs["hidden_states"] + q_len = hidden_states.shape[1] + + # Don't compress if the compression ratio is 0 or this is not pre-filling + if (self.compression_ratio == 0) or (cache.seen_tokens > q_len): + return output + + if isinstance(cache, QuantizedCache): + keys = cache._dequantize(cache._quantized_key_cache[module.layer_idx]) + values = cache._dequantize(cache._quantized_value_cache[module.layer_idx]) + else: + keys = cache.key_cache[module.layer_idx] + values = cache.value_cache[module.layer_idx] + + with torch.no_grad(): + scores = self.score(module, hidden_states, keys, values, attentions, kwargs) + + cache_metadata = module.metadata + num_key_value_heads = cache_metadata.num_key_value_heads + # Prune KV pairs with the lowest scores + n_kept = int(q_len * (1 - self.compression_ratio)) * num_key_value_heads + + # AdaKV paradigm + # TODO: current implementation only support bsz 1 + flatten_scores = scores.view( -1) + cache_topk_idx = flatten_scores.topk(n_kept, dim=-1).indices + head_len = cache_metadata.head_lens[0] + cache_topk_head_idx = cache_topk_idx // head_len + + compressed_head_lens = torch.zeros(num_key_value_heads, dtype=torch.int32,device=keys.device) + compressed_head_lens.scatter_add_(0, cache_topk_head_idx, torch.ones_like(cache_topk_head_idx, dtype=torch.int32)) + compressed_cu_seqlens_k = torch.cumsum(compressed_head_lens, dim=0, dtype=torch.int32) + # print(f"compressed_cu_seqlens_k: {compressed_cu_seqlens_k.dtype}") + + compressed_cu_seqlens_k = torch.cat([torch.tensor([0],dtype=torch.int32,device=keys.device), compressed_cu_seqlens_k]) + # print(f"compressed_cu_seqlens_k: {compressed_cu_seqlens_k.dtype}") + + compressed_max_seqlen_k = compressed_head_lens.max().cpu().item() + module._update_metadata_while_compressing(compressed_head_lens,compressed_cu_seqlens_k,compressed_max_seqlen_k) + + # sort the cache topk idx, cluster the retained cache in each head + sorted_4_cache_topk_idx = torch.argsort(cache_topk_head_idx,descending=False) + cache_topk_idx = cache_topk_idx[sorted_4_cache_topk_idx] + cache_topk_idx = cache_topk_idx.unsqueeze(-1).expand(-1,module.head_dim) + keys = keys.gather(0, cache_topk_idx).contiguous() + values = values.gather(0, cache_topk_idx).contiguous() + + if isinstance(cache, QuantizedCache): + cache._quantized_key_cache[module.layer_idx] = cache._quantize(keys, axis=cache.axis_key) + cache._quantized_value_cache[module.layer_idx] = cache._quantize(values, axis=cache.axis_value) + else: + cache.key_cache[module.layer_idx] = keys + cache.value_cache[module.layer_idx] = values + return output + + @contextmanager + def __call__(self, model: PreTrainedModel) -> Generator: + """ + Context manager to apply a compression method to a model. + Apply this context manager during the pre-filling phase to compress the context. + + Parameters + ---------- + model : PreTrainedModel + Model to apply the compression method to + """ + + if not isinstance(model, (LlamaForCausalLM, MistralForCausalLM, Phi3ForCausalLM, Qwen2ForCausalLM)): + logger.warning(f"Model {type(model)} not tested") + + try: + hooks = [] + for layer in model.model.layers: + hooks.append(layer.self_attn.register_forward_hook(self.forward_hook, with_kwargs=True)) + + yield + finally: + for forward_hook in hooks: + forward_hook.remove() +
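Reviewer note: a usage sketch (not part of the patch) of how AdaSnapKVPress is expected to plug into the existing KVPressTextGenerationPipeline. It assumes the CUDA extension in kvpress/csrc has been built (providing tiny_api_cuda), that flash-attn 2 is installed, and that the model's attention modules have been swapped to AdaLlamaFlashAttention, which is not wired up in this diff. The "kv-press-text-generation" task name and the returned "answer" key follow the existing kvpress pipeline registration; the checkpoint name is only an example.

from transformers import pipeline

from kvpress import AdaSnapKVPress

# SnapKV-style window scoring with AdaKV head-adaptive budget allocation
press = AdaSnapKVPress(compression_ratio=0.5, window_size=64, floor_alpha=0.2)

pipe = pipeline(
    "kv-press-text-generation",
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",  # example checkpoint
    device="cuda",
)

context = "..."   # long context compressed during pre-filling
question = "..."  # question answered against the compressed cache

# _forward picks DynamicCacheSplitHeadFlatten automatically because the press
# is an AdaBasePress instance (see the pipeline change above).
answer = pipe(context, question=question, press=press)["answer"]
print(answer)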