diff --git a/README.md b/README.md
index 145871ee..00612859 100644
--- a/README.md
+++ b/README.md
@@ -596,6 +596,7 @@ if __name__ == "__main__":
 | **enable\_llm\_cache** | `bool` | If `TRUE`, stores LLM results in cache; repeated prompts return cached responses | `TRUE` |
 | **addon\_params** | `dict` | Additional parameters, e.g., `{"example_number": 1, "language": "Simplified Chinese"}`: sets example limit and output language | `example_number: all examples, language: English` |
 | **convert\_response\_to\_json\_func** | `callable` | Not used | `convert_response_to_json` |
+| **embedding\_cache\_config** | `dict` | Configuration for the embedding-based LLM response cache. `enabled` (bool) toggles the cache; `similarity_threshold` (float) is the minimum cosine similarity required to return a cached response | `{"enabled": False, "similarity_threshold": 0.95}` |
 
 ## API Server Implementation
 
diff --git a/examples/lightrag_openai_compatible_demo_embedding_cache.py b/examples/lightrag_openai_compatible_demo_embedding_cache.py
new file mode 100644
index 00000000..69106d05
--- /dev/null
+++ b/examples/lightrag_openai_compatible_demo_embedding_cache.py
@@ -0,0 +1,112 @@
+import os
+import asyncio
+from lightrag import LightRAG, QueryParam
+from lightrag.llm import openai_complete_if_cache, openai_embedding
+from lightrag.utils import EmbeddingFunc
+import numpy as np
+
+WORKING_DIR = "./dickens"
+
+if not os.path.exists(WORKING_DIR):
+    os.mkdir(WORKING_DIR)
+
+
+async def llm_model_func(
+    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
+) -> str:
+    return await openai_complete_if_cache(
+        "solar-mini",
+        prompt,
+        system_prompt=system_prompt,
+        history_messages=history_messages,
+        api_key=os.getenv("UPSTAGE_API_KEY"),
+        base_url="https://api.upstage.ai/v1/solar",
+        **kwargs,
+    )
+
+
+async def embedding_func(texts: list[str]) -> np.ndarray:
+    return await openai_embedding(
+        texts,
+        model="solar-embedding-1-large-query",
+        api_key=os.getenv("UPSTAGE_API_KEY"),
+        base_url="https://api.upstage.ai/v1/solar",
+    )
+
+
+async def get_embedding_dim():
+    test_text = ["This is a test sentence."]
+    embedding = await embedding_func(test_text)
+    embedding_dim = embedding.shape[1]
+    return embedding_dim
+
+
+# function test
+async def test_funcs():
+    result = await llm_model_func("How are you?")
+    print("llm_model_func: ", result)
+
+    result = await embedding_func(["How are you?"])
+    print("embedding_func: ", result)
+
+
+# asyncio.run(test_funcs())
+
+
+async def main():
+    try:
+        embedding_dimension = await get_embedding_dim()
+        print(f"Detected embedding dimension: {embedding_dimension}")
+
+        rag = LightRAG(
+            working_dir=WORKING_DIR,
+            embedding_cache_config={
+                "enabled": True,
+                "similarity_threshold": 0.90,
+            },
+            llm_model_func=llm_model_func,
+            embedding_func=EmbeddingFunc(
+                embedding_dim=embedding_dimension,
+                max_token_size=8192,
+                func=embedding_func,
+            ),
+        )
+
+        with open("./book.txt", "r", encoding="utf-8") as f:
+            await rag.ainsert(f.read())
+
+        # Perform naive search
+        print(
+            await rag.aquery(
+                "What are the top themes in this story?", param=QueryParam(mode="naive")
+            )
+        )
+
+        # Perform local search
+        print(
+            await rag.aquery(
+                "What are the top themes in this story?", param=QueryParam(mode="local")
+            )
+        )
+
+        # Perform global search
+        print(
+            await rag.aquery(
+                "What are the top themes in this story?",
+                param=QueryParam(mode="global"),
+            )
+        )
+
+        # Perform hybrid search
+        print(
+            await rag.aquery(
+                "What are the top themes in this story?",
+                param=QueryParam(mode="hybrid"),
+            )
+        )
+    except Exception as e:
+        print(f"An error occurred: {e}")
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index a25dab79..0a44187e 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -85,7 +85,10 @@ class LightRAG:
     working_dir: str = field(
         default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
     )
-
+    # Embedding cache is disabled by default
+    embedding_cache_config: dict = field(
+        default_factory=lambda: {"enabled": False, "similarity_threshold": 0.95}
+    )
     kv_storage: str = field(default="JsonKVStorage")
     vector_storage: str = field(default="NanoVectorDBStorage")
     graph_storage: str = field(default="NetworkXStorage")
diff --git a/lightrag/llm.py b/lightrag/llm.py
index 2810d93e..33fdd182 100644
--- a/lightrag/llm.py
+++ b/lightrag/llm.py
@@ -33,6 +33,8 @@
     compute_args_hash,
     wrap_embedding_func_with_attrs,
     locate_json_string_body_from_string,
+    quantize_embedding,
+    get_best_cached_response,
 )
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -65,10 +67,29 @@ async def openai_complete_if_cache(
     messages.extend(history_messages)
     messages.append({"role": "user", "content": prompt})
     if hashing_kv is not None:
-        args_hash = compute_args_hash(model, messages)
-        if_cache_return = await hashing_kv.get_by_id(args_hash)
-        if if_cache_return is not None:
-            return if_cache_return["return"]
+        # Get embedding cache configuration
+        embedding_cache_config = hashing_kv.global_config.get(
+            "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
+        )
+        is_embedding_cache_enabled = embedding_cache_config["enabled"]
+        if is_embedding_cache_enabled:
+            # Use embedding cache
+            embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
+            current_embedding = await embedding_model_func([prompt])
+            quantized, min_val, max_val = quantize_embedding(current_embedding[0])
+            best_cached_response = await get_best_cached_response(
+                hashing_kv,
+                current_embedding[0],
+                similarity_threshold=embedding_cache_config["similarity_threshold"],
+            )
+            if best_cached_response is not None:
+                return best_cached_response
+        else:
+            # Use regular cache
+            args_hash = compute_args_hash(model, messages)
+            if_cache_return = await hashing_kv.get_by_id(args_hash)
+            if if_cache_return is not None:
+                return if_cache_return["return"]
 
     if "response_format" in kwargs:
         response = await openai_async_client.beta.chat.completions.parse(
@@ -81,10 +102,24 @@
     content = response.choices[0].message.content
     if r"\u" in content:
         content = content.encode("utf-8").decode("unicode_escape")
-    # print(content)
+
     if hashing_kv is not None:
         await hashing_kv.upsert(
-            {args_hash: {"return": response.choices[0].message.content, "model": model}}
+            {
+                args_hash: {
+                    "return": content,
+                    "model": model,
+                    "embedding": quantized.tobytes().hex()
+                    if is_embedding_cache_enabled
+                    else None,
+                    "embedding_shape": quantized.shape
+                    if is_embedding_cache_enabled
+                    else None,
+                    "embedding_min": min_val if is_embedding_cache_enabled else None,
+                    "embedding_max": max_val if is_embedding_cache_enabled else None,
+                    "original_prompt": prompt,
+                }
+            }
         )
     return content
 
@@ -125,10 +160,28 @@ async def azure_openai_complete_if_cache(
     if prompt is not None:
         messages.append({"role": "user", "content": prompt})
     if hashing_kv is not None:
-        args_hash = compute_args_hash(model, messages)
-        if_cache_return = await hashing_kv.get_by_id(args_hash)
-        if if_cache_return is not None:
-            return if_cache_return["return"]
+        # Get embedding cache configuration
+        embedding_cache_config = hashing_kv.global_config.get(
+            "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
+        )
+        is_embedding_cache_enabled = embedding_cache_config["enabled"]
+        if is_embedding_cache_enabled:
+            # Use embedding cache
+            embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
+            current_embedding = await embedding_model_func([prompt])
+            quantized, min_val, max_val = quantize_embedding(current_embedding[0])
+            best_cached_response = await get_best_cached_response(
+                hashing_kv,
+                current_embedding[0],
+                similarity_threshold=embedding_cache_config["similarity_threshold"],
+            )
+            if best_cached_response is not None:
+                return best_cached_response
+        else:
+            args_hash = compute_args_hash(model, messages)
+            if_cache_return = await hashing_kv.get_by_id(args_hash)
+            if if_cache_return is not None:
+                return if_cache_return["return"]
 
     response = await openai_async_client.chat.completions.create(
         model=model, messages=messages, **kwargs
@@ -136,7 +189,21 @@
 
     if hashing_kv is not None:
         await hashing_kv.upsert(
-            {args_hash: {"return": response.choices[0].message.content, "model": model}}
+            {
+                args_hash: {
+                    "return": response.choices[0].message.content,
+                    "model": model,
+                    "embedding": quantized.tobytes().hex()
+                    if is_embedding_cache_enabled
+                    else None,
+                    "embedding_shape": quantized.shape
+                    if is_embedding_cache_enabled
+                    else None,
+                    "embedding_min": min_val if is_embedding_cache_enabled else None,
+                    "embedding_max": max_val if is_embedding_cache_enabled else None,
+                    "original_prompt": prompt,
+                }
+            }
         )
     return response.choices[0].message.content
 
@@ -204,10 +271,29 @@ async def bedrock_complete_if_cache(
     hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
 
     if hashing_kv is not None:
-        args_hash = compute_args_hash(model, messages)
-        if_cache_return = await hashing_kv.get_by_id(args_hash)
-        if if_cache_return is not None:
-            return if_cache_return["return"]
+        # Get embedding cache configuration
+        embedding_cache_config = hashing_kv.global_config.get(
+            "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
+        )
+        is_embedding_cache_enabled = embedding_cache_config["enabled"]
+        if is_embedding_cache_enabled:
+            # Use embedding cache
+            embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
+            current_embedding = await embedding_model_func([prompt])
+            quantized, min_val, max_val = quantize_embedding(current_embedding[0])
+            best_cached_response = await get_best_cached_response(
+                hashing_kv,
+                current_embedding[0],
+                similarity_threshold=embedding_cache_config["similarity_threshold"],
+            )
+            if best_cached_response is not None:
+                return best_cached_response
+        else:
+            # Use regular cache
+            args_hash = compute_args_hash(model, messages)
+            if_cache_return = await hashing_kv.get_by_id(args_hash)
+            if if_cache_return is not None:
+                return if_cache_return["return"]
 
     # Call model via Converse API
     session = aioboto3.Session()
@@ -223,6 +309,19 @@
                 args_hash: {
                     "return": response["output"]["message"]["content"][0]["text"],
                     "model": model,
+                    "embedding": quantized.tobytes().hex()
+                    if is_embedding_cache_enabled
+                    else None,
+                    "embedding_shape": quantized.shape
+                    if is_embedding_cache_enabled
+                    else None,
+                    "embedding_min": min_val
+                    if is_embedding_cache_enabled
+                    else None,
+                    "embedding_max": max_val
+                    if is_embedding_cache_enabled
+                    else None,
+                    "original_prompt": prompt,
                 }
             }
         )
@@ -245,7 +344,11 @@ def initialize_hf_model(model_name):
 
 
 async def hf_model_if_cache(
-    model, prompt, system_prompt=None, history_messages=[], **kwargs
+    model,
+    prompt,
+    system_prompt=None,
+    history_messages=[],
+    **kwargs,
 ) -> str:
     model_name = model
     hf_model, hf_tokenizer = initialize_hf_model(model_name)
@@ -257,10 +360,30 @@ async def hf_model_if_cache(
     messages.append({"role": "user", "content": prompt})
 
     if hashing_kv is not None:
-        args_hash = compute_args_hash(model, messages)
-        if_cache_return = await hashing_kv.get_by_id(args_hash)
-        if if_cache_return is not None:
-            return if_cache_return["return"]
+        # Get embedding cache configuration
+        embedding_cache_config = hashing_kv.global_config.get(
+            "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
+        )
+        is_embedding_cache_enabled = embedding_cache_config["enabled"]
+        if is_embedding_cache_enabled:
+            # Use embedding cache
+            embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
+            current_embedding = await embedding_model_func([prompt])
+            quantized, min_val, max_val = quantize_embedding(current_embedding[0])
+            best_cached_response = await get_best_cached_response(
+                hashing_kv,
+                current_embedding[0],
+                similarity_threshold=embedding_cache_config["similarity_threshold"],
+            )
+            if best_cached_response is not None:
+                return best_cached_response
+        else:
+            # Use regular cache
+            args_hash = compute_args_hash(model, messages)
+            if_cache_return = await hashing_kv.get_by_id(args_hash)
+            if if_cache_return is not None:
+                return if_cache_return["return"]
+
     input_prompt = ""
     try:
         input_prompt = hf_tokenizer.apply_chat_template(
@@ -305,12 +428,32 @@
         output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
     )
     if hashing_kv is not None:
-        await hashing_kv.upsert({args_hash: {"return": response_text, "model": model}})
+        await hashing_kv.upsert(
+            {
+                args_hash: {
+                    "return": response_text,
+                    "model": model,
+                    "embedding": quantized.tobytes().hex()
+                    if is_embedding_cache_enabled
+                    else None,
+                    "embedding_shape": quantized.shape
+                    if is_embedding_cache_enabled
+                    else None,
+                    "embedding_min": min_val if is_embedding_cache_enabled else None,
+                    "embedding_max": max_val if is_embedding_cache_enabled else None,
+                    "original_prompt": prompt,
+                }
+            }
+        )
     return response_text
 
 
 async def ollama_model_if_cache(
-    model, prompt, system_prompt=None, history_messages=[], **kwargs
+    model,
+    prompt,
+    system_prompt=None,
+    history_messages=[],
+    **kwargs,
 ) -> str:
     kwargs.pop("max_tokens", None)
     # kwargs.pop("response_format", None) # allow json
@@ -326,18 +469,52 @@
     messages.extend(history_messages)
     messages.append({"role": "user", "content": prompt})
     if hashing_kv is not None:
-        args_hash = compute_args_hash(model, messages)
-        if_cache_return = await hashing_kv.get_by_id(args_hash)
-        if if_cache_return is not None:
-            return if_cache_return["return"]
+        # Get embedding cache configuration
+        embedding_cache_config = hashing_kv.global_config.get(
+            "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
+        )
+        is_embedding_cache_enabled = embedding_cache_config["enabled"]
+        if is_embedding_cache_enabled:
+            # Use embedding cache
+            embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
+            current_embedding = await embedding_model_func([prompt])
+            quantized, min_val, max_val = quantize_embedding(current_embedding[0])
+            best_cached_response = await get_best_cached_response(
+                hashing_kv,
+                current_embedding[0],
+                similarity_threshold=embedding_cache_config["similarity_threshold"],
+            )
+            if best_cached_response is not None:
+                return best_cached_response
+        else:
+            # Use regular cache
+            args_hash = compute_args_hash(model, messages)
+            if_cache_return = await hashing_kv.get_by_id(args_hash)
+            if if_cache_return is not None:
+                return if_cache_return["return"]
 
     response = await ollama_client.chat(model=model, messages=messages, **kwargs)
 
     result = response["message"]["content"]
 
     if hashing_kv is not None:
-        await hashing_kv.upsert({args_hash: {"return": result, "model": model}})
-
+        await hashing_kv.upsert(
+            {
+                args_hash: {
+                    "return": result,
+                    "model": model,
+                    "embedding": quantized.tobytes().hex()
+                    if is_embedding_cache_enabled
+                    else None,
+                    "embedding_shape": quantized.shape
+                    if is_embedding_cache_enabled
+                    else None,
+                    "embedding_min": min_val if is_embedding_cache_enabled else None,
+                    "embedding_max": max_val if is_embedding_cache_enabled else None,
+                    "original_prompt": prompt,
+                }
+            }
+        )
     return result
 
 
@@ -444,10 +621,29 @@
     messages.extend(history_messages)
     messages.append({"role": "user", "content": prompt})
     if hashing_kv is not None:
-        args_hash = compute_args_hash(model, messages)
-        if_cache_return = await hashing_kv.get_by_id(args_hash)
-        if if_cache_return is not None:
-            return if_cache_return["return"]
+        # Get embedding cache configuration
+        embedding_cache_config = hashing_kv.global_config.get(
+            "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
+        )
+        is_embedding_cache_enabled = embedding_cache_config["enabled"]
+        if is_embedding_cache_enabled:
+            # Use embedding cache
+            embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
+            current_embedding = await embedding_model_func([prompt])
+            quantized, min_val, max_val = quantize_embedding(current_embedding[0])
+            best_cached_response = await get_best_cached_response(
+                hashing_kv,
+                current_embedding[0],
+                similarity_threshold=embedding_cache_config["similarity_threshold"],
+            )
+            if best_cached_response is not None:
+                return best_cached_response
+        else:
+            # Use regular cache
+            args_hash = compute_args_hash(model, messages)
+            if_cache_return = await hashing_kv.get_by_id(args_hash)
+            if if_cache_return is not None:
+                return if_cache_return["return"]
 
     gen_config = GenerationConfig(
         skip_special_tokens=skip_special_tokens,
@@ -466,7 +662,23 @@
         response += res.response
 
     if hashing_kv is not None:
-        await hashing_kv.upsert({args_hash: {"return": response, "model": model}})
+        await hashing_kv.upsert(
+            {
+                args_hash: {
+                    "return": response,
+                    "model": model,
+                    "embedding": quantized.tobytes().hex()
+                    if is_embedding_cache_enabled
+                    else None,
+                    "embedding_shape": quantized.shape
+                    if is_embedding_cache_enabled
+                    else None,
+                    "embedding_min": min_val if is_embedding_cache_enabled else None,
+                    "embedding_max": max_val if is_embedding_cache_enabled else None,
+                    "original_prompt": prompt,
+                }
+            }
+        )
     return response
 
diff --git a/lightrag/utils.py b/lightrag/utils.py
index 8997b651..d080ee03 100644
--- a/lightrag/utils.py
+++ b/lightrag/utils.py
@@ -307,3 +307,72 @@ def process_combine_contexts(hl, ll):
     combined_sources_result = "\n".join(combined_sources_result)
 
     return combined_sources_result
+
+
+async def get_best_cached_response(
+    hashing_kv, current_embedding, similarity_threshold=0.95
+):
+    """Return the most similar cached response, or None if no entry exceeds the similarity threshold"""
+    try:
+        # Get all keys
+        all_keys = await hashing_kv.all_keys()
+        max_similarity = 0
+        best_cached_response = None
+
+        # Get cached data one by one
+        for key in all_keys:
+            cache_data = await hashing_kv.get_by_id(key)
+            if cache_data is None or "embedding" not in cache_data:
+                continue
+
+            # Rebuild the quantized embedding from its hex-encoded bytes
+            cached_quantized = np.frombuffer(
+                bytes.fromhex(cache_data["embedding"]), dtype=np.uint8
+            ).reshape(cache_data["embedding_shape"])
+            cached_embedding = dequantize_embedding(
+                cached_quantized,
+                cache_data["embedding_min"],
+                cache_data["embedding_max"],
+            )
+
+            similarity = cosine_similarity(current_embedding, cached_embedding)
+            if similarity > max_similarity:
+                max_similarity = similarity
+                best_cached_response = cache_data["return"]
+
+        if max_similarity > similarity_threshold:
+            return best_cached_response
+        return None
+
+    except Exception as e:
+        logger.warning(f"Error in get_best_cached_response: {e}")
+        return None
+
+
+def cosine_similarity(v1, v2):
+    """Calculate cosine similarity between two vectors"""
+    dot_product = np.dot(v1, v2)
+    norm1 = np.linalg.norm(v1)
+    norm2 = np.linalg.norm(v2)
+    return dot_product / (norm1 * norm2)
+
+
+def quantize_embedding(embedding: np.ndarray, bits=8) -> tuple:
+    """Quantize embedding to specified bits"""
+    # Calculate min/max values for reconstruction
+    min_val = embedding.min()
+    max_val = embedding.max()
+
+    # Quantize to 0-255 range
+    scale = (2**bits - 1) / (max_val - min_val)
+    quantized = np.round((embedding - min_val) * scale).astype(np.uint8)
+
+    return quantized, min_val, max_val
+
+
+def dequantize_embedding(
+    quantized: np.ndarray, min_val: float, max_val: float, bits=8
+) -> np.ndarray:
+    """Restore quantized embedding"""
+    scale = (max_val - min_val) / (2**bits - 1)
+    return (quantized * scale + min_val).astype(np.float32)
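A minimal sketch of the quantize/store/compare round trip that the new `lightrag/utils.py` helpers perform, for readers who want to see the cache mechanics in isolation. It assumes the diff above is applied and `lightrag` is importable; the random vector is only a stand-in for a real model embedding, and `0.95` mirrors the default `similarity_threshold`.

```python
# Sketch of the embedding-cache round trip (assumes this diff is applied).
import numpy as np

from lightrag.utils import (
    quantize_embedding,
    dequantize_embedding,
    cosine_similarity,
)

rng = np.random.default_rng(0)
prompt_embedding = rng.normal(size=1024).astype(np.float32)  # stand-in embedding

# Cache write: quantize to uint8, store as hex plus min/max/shape for reconstruction
quantized, min_val, max_val = quantize_embedding(prompt_embedding)
cache_entry = {
    "embedding": quantized.tobytes().hex(),
    "embedding_shape": quantized.shape,
    "embedding_min": min_val,
    "embedding_max": max_val,
    "return": "cached LLM answer",
}

# Cache read: rebuild the embedding and compare it with the new prompt's embedding
restored = dequantize_embedding(
    np.frombuffer(bytes.fromhex(cache_entry["embedding"]), dtype=np.uint8).reshape(
        cache_entry["embedding_shape"]
    ),
    cache_entry["embedding_min"],
    cache_entry["embedding_max"],
)

similarity = cosine_similarity(prompt_embedding, restored)
if similarity > 0.95:  # same default as similarity_threshold
    print("cache hit:", cache_entry["return"])
else:
    print(f"cache miss (similarity={similarity:.4f})")
```

Storing the uint8 quantized vector as a hex string together with its min/max and shape keeps each cache entry JSON-serializable for `JsonKVStorage`, while still letting `get_best_cached_response` reconstruct an embedding close enough to the original for cosine-similarity matching.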