diff --git a/README.md b/README.md index a2cbb217..a1454792 100644 --- a/README.md +++ b/README.md @@ -596,11 +596,7 @@ if __name__ == "__main__": | **enable\_llm\_cache** | `bool` | If `TRUE`, stores LLM results in cache; repeated prompts return cached responses | `TRUE` | | **addon\_params** | `dict` | Additional parameters, e.g., `{"example_number": 1, "language": "Simplified Chinese"}`: sets example limit and output language | `example_number: all examples, language: English` | | **convert\_response\_to\_json\_func** | `callable` | Not used | `convert_response_to_json` | -| **embedding\_cache\_config** | `dict` | Configuration for question-answer caching. Contains two parameters: -- `enabled`: Boolean value to enable/disable caching functionality. When enabled, questions and answers will be cached. -- `similarity_threshold`: Float value (0-1), similarity threshold. When a new question's similarity with a cached question exceeds this threshold, the cached answer will be returned directly without calling the LLM. - -Default: `{"enabled": False, "similarity_threshold": 0.95}` | `{"enabled": False, "similarity_threshold": 0.95}` | +| **embedding\_cache\_config** | `dict` | Configuration for question-answer caching. Contains three parameters:
- `enabled`: Boolean value to enable/disable cache lookup functionality. When enabled, the system will check cached responses before generating new answers.
- `similarity_threshold`: Float value (0-1). When a new question's similarity with a cached question exceeds this threshold, the cached answer is returned directly without calling the LLM.
- `use_llm_check`: Boolean value to enable/disable LLM similarity verification. When enabled, LLM will be used as a secondary check to verify the similarity between questions before returning cached answers. | Default: `{"enabled": False, "similarity_threshold": 0.95, "use_llm_check": False}` | ## API Server Implementation diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 0a44187e..0eb1b27e 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -87,7 +87,11 @@ class LightRAG: ) # Default not to use embedding cache embedding_cache_config: dict = field( - default_factory=lambda: {"enabled": False, "similarity_threshold": 0.95} + default_factory=lambda: { + "enabled": False, + "similarity_threshold": 0.95, + "use_llm_check": False, + } ) kv_storage: str = field(default="JsonKVStorage") vector_storage: str = field(default="NanoVectorDBStorage") @@ -174,7 +178,6 @@ def __post_init__(self): if self.enable_llm_cache else None ) - self.embedding_func = limit_async_func_call(self.embedding_func_max_async)( self.embedding_func ) @@ -481,6 +484,7 @@ async def aquery(self, query: str, param: QueryParam = QueryParam()): self.text_chunks, param, asdict(self), + hashing_kv=self.llm_response_cache, ) elif param.mode == "naive": response = await naive_query( @@ -489,6 +493,7 @@ async def aquery(self, query: str, param: QueryParam = QueryParam()): self.text_chunks, param, asdict(self), + hashing_kv=self.llm_response_cache, ) else: raise ValueError(f"Unknown mode {param.mode}") diff --git a/lightrag/llm.py b/lightrag/llm.py index 63913c90..b2bb99b7 100644 --- a/lightrag/llm.py +++ b/lightrag/llm.py @@ -4,8 +4,7 @@ import os import struct from functools import lru_cache -from typing import List, Dict, Callable, Any, Union, Optional -from dataclasses import dataclass +from typing import List, Dict, Callable, Any, Union import aioboto3 import aiohttp import numpy as np @@ -27,13 +26,9 @@ ) from transformers import AutoTokenizer, AutoModelForCausalLM -from .base import BaseKVStorage from .utils import ( - compute_args_hash, wrap_embedding_func_with_attrs, locate_json_string_body_from_string, - quantize_embedding, - get_best_cached_response, ) import sys @@ -66,23 +61,13 @@ async def openai_complete_if_cache( openai_async_client = ( AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url) ) - + kwargs.pop("hashing_kv", None) messages = [] if system_prompt: messages.append({"role": "system", "content": system_prompt}) messages.extend(history_messages) messages.append({"role": "user", "content": prompt}) - hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) - # Handle cache - mode = kwargs.pop("mode", "default") - args_hash = compute_args_hash(model, messages) - cached_response, quantized, min_val, max_val = await handle_cache( - hashing_kv, args_hash, prompt, mode - ) - if cached_response is not None: - return cached_response - if "response_format" in kwargs: response = await openai_async_client.beta.chat.completions.parse( model=model, messages=messages, **kwargs @@ -95,21 +80,6 @@ async def openai_complete_if_cache( if r"\u" in content: content = content.encode("utf-8").decode("unicode_escape") - # Save to cache - await save_to_cache( - hashing_kv, - CacheData( - args_hash=args_hash, - content=content, - model=model, - prompt=prompt, - quantized=quantized, - min_val=min_val, - max_val=max_val, - mode=mode, - ), - ) - return content @@ -140,10 +110,7 @@ async def azure_openai_complete_if_cache( api_key=os.getenv("AZURE_OPENAI_API_KEY"), 
api_version=os.getenv("AZURE_OPENAI_API_VERSION"), ) - - hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) - mode = kwargs.pop("mode", "default") - + kwargs.pop("hashing_kv", None) messages = [] if system_prompt: messages.append({"role": "system", "content": system_prompt}) @@ -151,34 +118,11 @@ async def azure_openai_complete_if_cache( if prompt is not None: messages.append({"role": "user", "content": prompt}) - # Handle cache - args_hash = compute_args_hash(model, messages) - cached_response, quantized, min_val, max_val = await handle_cache( - hashing_kv, args_hash, prompt, mode - ) - if cached_response is not None: - return cached_response - response = await openai_async_client.chat.completions.create( model=model, messages=messages, **kwargs ) content = response.choices[0].message.content - # Save to cache - await save_to_cache( - hashing_kv, - CacheData( - args_hash=args_hash, - content=content, - model=model, - prompt=prompt, - quantized=quantized, - min_val=min_val, - max_val=max_val, - mode=mode, - ), - ) - return content @@ -210,7 +154,7 @@ async def bedrock_complete_if_cache( os.environ["AWS_SESSION_TOKEN"] = os.environ.get( "AWS_SESSION_TOKEN", aws_session_token ) - + kwargs.pop("hashing_kv", None) # Fix message history format messages = [] for history_message in history_messages: @@ -220,15 +164,6 @@ async def bedrock_complete_if_cache( # Add user prompt messages.append({"role": "user", "content": [{"text": prompt}]}) - hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) - # Handle cache - mode = kwargs.pop("mode", "default") - args_hash = compute_args_hash(model, messages) - cached_response, quantized, min_val, max_val = await handle_cache( - hashing_kv, args_hash, prompt, mode - ) - if cached_response is not None: - return cached_response # Initialize Converse API arguments args = {"modelId": model, "messages": messages} @@ -251,15 +186,6 @@ async def bedrock_complete_if_cache( args["inferenceConfig"][inference_params_map.get(param, param)] = ( kwargs.pop(param) ) - hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) - # Handle cache - mode = kwargs.pop("mode", "default") - args_hash = compute_args_hash(model, messages) - cached_response, quantized, min_val, max_val = await handle_cache( - hashing_kv, args_hash, prompt, mode - ) - if cached_response is not None: - return cached_response # Call model via Converse API session = aioboto3.Session() @@ -269,21 +195,6 @@ async def bedrock_complete_if_cache( except Exception as e: raise BedrockError(e) - # Save to cache - await save_to_cache( - hashing_kv, - CacheData( - args_hash=args_hash, - content=response["output"]["message"]["content"][0]["text"], - model=model, - prompt=prompt, - quantized=quantized, - min_val=min_val, - max_val=max_val, - mode=mode, - ), - ) - return response["output"]["message"]["content"][0]["text"] @@ -315,22 +226,12 @@ async def hf_model_if_cache( ) -> str: model_name = model hf_model, hf_tokenizer = initialize_hf_model(model_name) - hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) messages = [] if system_prompt: messages.append({"role": "system", "content": system_prompt}) messages.extend(history_messages) messages.append({"role": "user", "content": prompt}) - - # Handle cache - mode = kwargs.pop("mode", "default") - args_hash = compute_args_hash(model, messages) - cached_response, quantized, min_val, max_val = await handle_cache( - hashing_kv, args_hash, prompt, mode - ) - if cached_response is not None: - return cached_response - + kwargs.pop("hashing_kv", None) 
input_prompt = "" try: input_prompt = hf_tokenizer.apply_chat_template( @@ -375,21 +276,6 @@ async def hf_model_if_cache( output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True ) - # Save to cache - await save_to_cache( - hashing_kv, - CacheData( - args_hash=args_hash, - content=response_text, - model=model, - prompt=prompt, - quantized=quantized, - min_val=min_val, - max_val=max_val, - mode=mode, - ), - ) - return response_text @@ -410,25 +296,14 @@ async def ollama_model_if_cache( # kwargs.pop("response_format", None) # allow json host = kwargs.pop("host", None) timeout = kwargs.pop("timeout", None) - + kwargs.pop("hashing_kv", None) ollama_client = ollama.AsyncClient(host=host, timeout=timeout) messages = [] if system_prompt: messages.append({"role": "system", "content": system_prompt}) - - hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) messages.extend(history_messages) messages.append({"role": "user", "content": prompt}) - # Handle cache - mode = kwargs.pop("mode", "default") - args_hash = compute_args_hash(model, messages) - cached_response, quantized, min_val, max_val = await handle_cache( - hashing_kv, args_hash, prompt, mode - ) - if cached_response is not None: - return cached_response - response = await ollama_client.chat(model=model, messages=messages, **kwargs) if stream: """ cannot cache stream response """ @@ -439,40 +314,7 @@ async def inner(): return inner() else: - result = response["message"]["content"] - # Save to cache - await save_to_cache( - hashing_kv, - CacheData( - args_hash=args_hash, - content=result, - model=model, - prompt=prompt, - quantized=quantized, - min_val=min_val, - max_val=max_val, - mode=mode, - ), - ) - return result - result = response["message"]["content"] - - # Save to cache - await save_to_cache( - hashing_kv, - CacheData( - args_hash=args_hash, - content=result, - model=model, - prompt=prompt, - quantized=quantized, - min_val=min_val, - max_val=max_val, - mode=mode, - ), - ) - - return result + return response["message"]["content"] @lru_cache(maxsize=1) @@ -547,7 +389,7 @@ async def lmdeploy_model_if_cache( from lmdeploy import version_info, GenerationConfig except Exception: raise ImportError("Please install lmdeploy before intialize lmdeploy backend.") - + kwargs.pop("hashing_kv", None) kwargs.pop("response_format", None) max_new_tokens = kwargs.pop("max_tokens", 512) tp = kwargs.pop("tp", 1) @@ -579,19 +421,9 @@ async def lmdeploy_model_if_cache( if system_prompt: messages.append({"role": "system", "content": system_prompt}) - hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) messages.extend(history_messages) messages.append({"role": "user", "content": prompt}) - # Handle cache - mode = kwargs.pop("mode", "default") - args_hash = compute_args_hash(model, messages) - cached_response, quantized, min_val, max_val = await handle_cache( - hashing_kv, args_hash, prompt, mode - ) - if cached_response is not None: - return cached_response - gen_config = GenerationConfig( skip_special_tokens=skip_special_tokens, max_new_tokens=max_new_tokens, @@ -607,22 +439,6 @@ async def lmdeploy_model_if_cache( session_id=1, ): response += res.response - - # Save to cache - await save_to_cache( - hashing_kv, - CacheData( - args_hash=args_hash, - content=response, - model=model, - prompt=prompt, - quantized=quantized, - min_val=min_val, - max_val=max_val, - mode=mode, - ), - ) - return response @@ -1052,75 +868,6 @@ async def llm_model_func( return await next_model.gen_func(**args) -async def handle_cache(hashing_kv, args_hash, 
prompt, mode="default"): - """Generic cache handling function""" - if hashing_kv is None: - return None, None, None, None - - # Get embedding cache configuration - embedding_cache_config = hashing_kv.global_config.get( - "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95} - ) - is_embedding_cache_enabled = embedding_cache_config["enabled"] - - quantized = min_val = max_val = None - if is_embedding_cache_enabled: - # Use embedding cache - embedding_model_func = hashing_kv.global_config["embedding_func"]["func"] - current_embedding = await embedding_model_func([prompt]) - quantized, min_val, max_val = quantize_embedding(current_embedding[0]) - best_cached_response = await get_best_cached_response( - hashing_kv, - current_embedding[0], - similarity_threshold=embedding_cache_config["similarity_threshold"], - mode=mode, - ) - if best_cached_response is not None: - return best_cached_response, None, None, None - else: - # Use regular cache - mode_cache = await hashing_kv.get_by_id(mode) or {} - if args_hash in mode_cache: - return mode_cache[args_hash]["return"], None, None, None - - return None, quantized, min_val, max_val - - -@dataclass -class CacheData: - args_hash: str - content: str - model: str - prompt: str - quantized: Optional[np.ndarray] = None - min_val: Optional[float] = None - max_val: Optional[float] = None - mode: str = "default" - - -async def save_to_cache(hashing_kv, cache_data: CacheData): - if hashing_kv is None: - return - - mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {} - - mode_cache[cache_data.args_hash] = { - "return": cache_data.content, - "model": cache_data.model, - "embedding": cache_data.quantized.tobytes().hex() - if cache_data.quantized is not None - else None, - "embedding_shape": cache_data.quantized.shape - if cache_data.quantized is not None - else None, - "embedding_min": cache_data.min_val, - "embedding_max": cache_data.max_val, - "original_prompt": cache_data.prompt, - } - - await hashing_kv.upsert({cache_data.mode: mode_cache}) - - if __name__ == "__main__": import asyncio diff --git a/lightrag/operate.py b/lightrag/operate.py index acbdf072..feaec27d 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -17,6 +17,10 @@ split_string_by_multi_markers, truncate_list_by_token_size, process_combine_contexts, + compute_args_hash, + handle_cache, + save_to_cache, + CacheData, ) from .base import ( BaseGraphStorage, @@ -452,8 +456,17 @@ async def kg_query( text_chunks_db: BaseKVStorage[TextChunkSchema], query_param: QueryParam, global_config: dict, + hashing_kv: BaseKVStorage = None, ) -> str: - context = None + # Handle cache + use_model_func = global_config["llm_model_func"] + args_hash = compute_args_hash(query_param.mode, query) + cached_response, quantized, min_val, max_val = await handle_cache( + hashing_kv, args_hash, query, query_param.mode + ) + if cached_response is not None: + return cached_response + example_number = global_config["addon_params"].get("example_number", None) if example_number and example_number < len(PROMPTS["keywords_extraction_examples"]): examples = "\n".join( @@ -471,12 +484,9 @@ async def kg_query( return PROMPTS["fail_response"] # LLM generate keywords - use_model_func = global_config["llm_model_func"] kw_prompt_temp = PROMPTS["keywords_extraction"] kw_prompt = kw_prompt_temp.format(query=query, examples=examples, language=language) - result = await use_model_func( - kw_prompt, keyword_extraction=True, mode=query_param.mode - ) + result = await use_model_func(kw_prompt, 
keyword_extraction=True) logger.info("kw_prompt result:") print(result) try: @@ -537,7 +547,6 @@ async def kg_query( query, system_prompt=sys_prompt, stream=query_param.stream, - mode=query_param.mode, ) if isinstance(response, str) and len(response) > len(sys_prompt): response = ( @@ -550,6 +559,20 @@ async def kg_query( .strip() ) + # Save to cache + await save_to_cache( + hashing_kv, + CacheData( + args_hash=args_hash, + content=response, + prompt=query, + quantized=quantized, + min_val=min_val, + max_val=max_val, + mode=query_param.mode, + ), + ) + return response @@ -1013,8 +1036,17 @@ async def naive_query( text_chunks_db: BaseKVStorage[TextChunkSchema], query_param: QueryParam, global_config: dict, + hashing_kv: BaseKVStorage = None, ): + # Handle cache use_model_func = global_config["llm_model_func"] + args_hash = compute_args_hash(query_param.mode, query) + cached_response, quantized, min_val, max_val = await handle_cache( + hashing_kv, args_hash, query, query_param.mode + ) + if cached_response is not None: + return cached_response + results = await chunks_vdb.query(query, top_k=query_param.top_k) if not len(results): return PROMPTS["fail_response"] @@ -1039,7 +1071,6 @@ async def naive_query( response = await use_model_func( query, system_prompt=sys_prompt, - mode=query_param.mode, ) if len(response) > len(sys_prompt): @@ -1054,4 +1085,18 @@ async def naive_query( .strip() ) + # Save to cache + await save_to_cache( + hashing_kv, + CacheData( + args_hash=args_hash, + content=response, + prompt=query, + quantized=quantized, + min_val=min_val, + max_val=max_val, + mode=query_param.mode, + ), + ) + return response diff --git a/lightrag/prompt.py b/lightrag/prompt.py index d758397b..863d38dc 100644 --- a/lightrag/prompt.py +++ b/lightrag/prompt.py @@ -261,3 +261,22 @@ Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown. """ + +PROMPTS[ + "similarity_check" +] = """Please analyze the similarity between these two questions: + +Question 1: {original_prompt} +Question 2: {cached_prompt} + +Please evaluate: +1. Whether these two questions are semantically similar +2. Whether the answer to Question 2 can be used to answer Question 1 + +Please provide a similarity score between 0 and 1, where: +0: Completely unrelated or answer cannot be reused +1: Identical and answer can be directly reused +0.5: Partially related and answer needs modification to be used + +Return only a number between 0-1, without any additional content. 
+""" diff --git a/lightrag/utils.py b/lightrag/utils.py index 4c8d7996..32d5c87f 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -9,12 +9,14 @@ from dataclasses import dataclass from functools import wraps from hashlib import md5 -from typing import Any, Union, List +from typing import Any, Union, List, Optional import xml.etree.ElementTree as ET import numpy as np import tiktoken +from lightrag.prompt import PROMPTS + ENCODER = None logger = logging.getLogger("lightrag") @@ -314,6 +316,9 @@ async def get_best_cached_response( current_embedding, similarity_threshold=0.95, mode="default", + use_llm_check=False, + llm_func=None, + original_prompt=None, ) -> Union[str, None]: # Get mode-specific cache mode_cache = await hashing_kv.get_by_id(mode) @@ -348,6 +353,37 @@ async def get_best_cached_response( best_cache_id = cache_id if best_similarity > similarity_threshold: + # If LLM check is enabled and all required parameters are provided + if use_llm_check and llm_func and original_prompt and best_prompt: + compare_prompt = PROMPTS["similarity_check"].format( + original_prompt=original_prompt, cached_prompt=best_prompt + ) + + try: + llm_result = await llm_func(compare_prompt) + llm_result = llm_result.strip() + llm_similarity = float(llm_result) + + # Replace vector similarity with LLM similarity score + best_similarity = llm_similarity + if best_similarity < similarity_threshold: + log_data = { + "event": "llm_check_cache_rejected", + "original_question": original_prompt[:100] + "..." + if len(original_prompt) > 100 + else original_prompt, + "cached_question": best_prompt[:100] + "..." + if len(best_prompt) > 100 + else best_prompt, + "similarity_score": round(best_similarity, 4), + "threshold": similarity_threshold, + } + logger.info(json.dumps(log_data, ensure_ascii=False)) + return None + except Exception as e: # Catch all possible exceptions + logger.warning(f"LLM similarity check failed: {e}") + return None # Return None directly when LLM check fails + prompt_display = ( best_prompt[:50] + "..." 
if len(best_prompt) > 50 else best_prompt ) @@ -390,3 +426,84 @@ def dequantize_embedding( """Restore quantized embedding""" scale = (max_val - min_val) / (2**bits - 1) return (quantized * scale + min_val).astype(np.float32) + + +async def handle_cache(hashing_kv, args_hash, prompt, mode="default"): + """Generic cache handling function""" + if hashing_kv is None: + return None, None, None, None + + # For naive mode, only use simple cache matching + if mode == "naive": + mode_cache = await hashing_kv.get_by_id(mode) or {} + if args_hash in mode_cache: + return mode_cache[args_hash]["return"], None, None, None + return None, None, None, None + + # Get embedding cache configuration + embedding_cache_config = hashing_kv.global_config.get( + "embedding_cache_config", + {"enabled": False, "similarity_threshold": 0.95, "use_llm_check": False}, + ) + is_embedding_cache_enabled = embedding_cache_config["enabled"] + use_llm_check = embedding_cache_config.get("use_llm_check", False) + + quantized = min_val = max_val = None + if is_embedding_cache_enabled: + # Use embedding cache + embedding_model_func = hashing_kv.global_config["embedding_func"]["func"] + llm_model_func = hashing_kv.global_config.get("llm_model_func") + + current_embedding = await embedding_model_func([prompt]) + quantized, min_val, max_val = quantize_embedding(current_embedding[0]) + best_cached_response = await get_best_cached_response( + hashing_kv, + current_embedding[0], + similarity_threshold=embedding_cache_config["similarity_threshold"], + mode=mode, + use_llm_check=use_llm_check, + llm_func=llm_model_func if use_llm_check else None, + original_prompt=prompt if use_llm_check else None, + ) + if best_cached_response is not None: + return best_cached_response, None, None, None + else: + # Use regular cache + mode_cache = await hashing_kv.get_by_id(mode) or {} + if args_hash in mode_cache: + return mode_cache[args_hash]["return"], None, None, None + + return None, quantized, min_val, max_val + + +@dataclass +class CacheData: + args_hash: str + content: str + prompt: str + quantized: Optional[np.ndarray] = None + min_val: Optional[float] = None + max_val: Optional[float] = None + mode: str = "default" + + +async def save_to_cache(hashing_kv, cache_data: CacheData): + if hashing_kv is None: + return + + mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {} + + mode_cache[cache_data.args_hash] = { + "return": cache_data.content, + "embedding": cache_data.quantized.tobytes().hex() + if cache_data.quantized is not None + else None, + "embedding_shape": cache_data.quantized.shape + if cache_data.quantized is not None + else None, + "embedding_min": cache_data.min_val, + "embedding_max": cache_data.max_val, + "original_prompt": cache_data.prompt, + } + + await hashing_kv.upsert({cache_data.mode: mode_cache})
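
For reviewers, a minimal usage sketch of the `embedding_cache_config` option documented in the README hunk above. This is not part of the diff: the working directory, corpus text, and example questions are illustrative, and `gpt_4o_mini_complete` / `QueryParam` are taken from the existing LightRAG API described in the README.

```python
# Sketch only: enabling the question-answer cache introduced in this change.
# Paths and model choices are placeholders; see the README for full setup
# (API keys, embedding function, etc.).
from lightrag import LightRAG, QueryParam
from lightrag.llm import gpt_4o_mini_complete

rag = LightRAG(
    working_dir="./rag_storage",
    llm_model_func=gpt_4o_mini_complete,
    embedding_cache_config={
        "enabled": True,               # look up cached answers before calling the LLM
        "similarity_threshold": 0.95,  # vector-similarity gate for reusing an answer
        "use_llm_check": True,         # secondary LLM verification of the match
    },
)

rag.insert("Some previously prepared corpus text ...")

# The first query is answered by the LLM and written to llm_response_cache;
# a sufficiently similar rephrasing can then be served from the cache.
print(rag.query("What are the main themes?", param=QueryParam(mode="hybrid")))
print(rag.query("What are the key themes?", param=QueryParam(mode="hybrid")))
```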
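
A second small sketch shows how a cache entry's embedding round-trips through `quantize_embedding` and `dequantize_embedding`; `save_to_cache` stores exactly these hex/shape/min/max fields. The `uint8` dtype and the 8-bit default are assumptions inferred from the `(2**bits - 1)` scale in `dequantize_embedding`, not spelled out in this diff.

```python
import numpy as np
from lightrag.utils import quantize_embedding, dequantize_embedding

embedding = np.random.rand(768).astype(np.float32)  # stand-in for a query embedding
quantized, min_val, max_val = quantize_embedding(embedding)

# Mirror of what save_to_cache writes into the per-mode cache entry.
entry = {
    "embedding": quantized.tobytes().hex(),
    "embedding_shape": quantized.shape,
    "embedding_min": min_val,
    "embedding_max": max_val,
}

# Restoring the vector for the similarity comparison performed in
# get_best_cached_response (uint8 assumed for 8-bit quantization).
packed = np.frombuffer(bytes.fromhex(entry["embedding"]), dtype=np.uint8)
restored = dequantize_embedding(
    packed.reshape(entry["embedding_shape"]),
    entry["embedding_min"],
    entry["embedding_max"],
)

# Reconstruction error is bounded by half a quantization step: (max - min) / (2 * 255).
print(float(np.abs(embedding - restored).max()))
```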