Merge pull request #408 from magicyuan876/main
Fix a bug where args_hash was only computed when using the regular cache, so it was never computed when the embedding cache was used
LarFii authored Dec 6, 2024
2 parents 2a2756d + 7c4bbe2 commit 8a6796b
Showing 2 changed files with 54 additions and 18 deletions.
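
The fix moves the `args_hash` computation up so it runs whenever `hashing_kv` is present, making the hash available to both the embedding-similarity cache and the regular hash-keyed cache; the same change is repeated in each `*_complete_if_cache` / `*_model_if_cache` helper in `lightrag/llm.py`. Below is a minimal sketch of the resulting flow, not the actual LightRAG code: only `compute_args_hash`, `hashing_kv.global_config`, and `hashing_kv.get_by_id` mirror the diff, while `similarity_lookup` and the local `compute_args_hash` stub are illustrative stand-ins for parts collapsed in the hunks.

```python
# Minimal sketch of the caching flow after this fix; illustrative, not the real code.
import hashlib
from typing import Any, Optional


def compute_args_hash(*args: Any) -> str:
    # Stand-in for lightrag.utils.compute_args_hash: hash of the stringified args.
    return hashlib.md5(str(args).encode()).hexdigest()


async def complete_if_cache_sketch(model, messages, hashing_kv=None, similarity_lookup=None) -> Optional[str]:
    if hashing_kv is not None:
        # The fix: compute args_hash as soon as any cache is in play,
        # not only inside the regular-cache branch further down.
        args_hash = compute_args_hash(model, messages)

        embedding_cache_config = hashing_kv.global_config.get(
            "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
        )
        if embedding_cache_config["enabled"] and similarity_lookup is not None:
            # Embedding cache: return a sufficiently similar cached answer, if any.
            best_cached_response = await similarity_lookup(
                messages, embedding_cache_config["similarity_threshold"]
            )
            if best_cached_response is not None:
                return best_cached_response
        else:
            # Regular cache: exact lookup keyed by args_hash.
            if_cache_return = await hashing_kv.get_by_id(args_hash)
            if if_cache_return is not None:
                return if_cache_return["return"]

    # ...call the underlying LLM here and store the new result under args_hash...
    return None
```

Before this change, `args_hash` was assigned only on the regular-cache branch, so code paths that need it while the embedding cache is enabled never received a value.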
6 changes: 5 additions & 1 deletion README.md
@@ -596,7 +596,11 @@ if __name__ == "__main__":
| **enable\_llm\_cache** | `bool` | If `TRUE`, stores LLM results in cache; repeated prompts return cached responses | `TRUE` |
| **addon\_params** | `dict` | Additional parameters, e.g., `{"example_number": 1, "language": "Simplified Chinese"}`: sets example limit and output language | `example_number: all examples, language: English` |
| **convert\_response\_to\_json\_func** | `callable` | Not used | `convert_response_to_json` |
| **embedding\_cache\_config** | `dict` | Configuration for embedding cache. Includes `enabled` (bool) to toggle cache and `similarity_threshold` (float) for cache retrieval | `{"enabled": False, "similarity_threshold": 0.95}` |
| **embedding\_cache\_config** | `dict` | Configuration for question-answer caching. Contains two parameters:<br>- `enabled`: Boolean to enable/disable the cache. When enabled, questions and answers are cached.<br>- `similarity_threshold`: Float in (0, 1); when a new question's similarity to a cached question exceeds this threshold, the cached answer is returned directly without calling the LLM. | `{"enabled": False, "similarity_threshold": 0.95}` |

## API Server Implementation

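For reference, a hedged usage sketch of the `embedding_cache_config` option documented above; the `working_dir` value and any other constructor arguments are placeholders rather than part of this diff.

```python
# Hedged usage sketch for the embedding_cache_config option documented above.
# working_dir is a placeholder; a real setup would also pass its own LLM and
# embedding functions as usual.
from lightrag import LightRAG

rag = LightRAG(
    working_dir="./rag_storage",  # placeholder path
    embedding_cache_config={
        "enabled": True,               # cache question/answer pairs
        "similarity_threshold": 0.95,  # reuse an answer above this similarity
    },
)
```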
66 changes: 49 additions & 17 deletions lightrag/llm.py
@@ -1,33 +1,32 @@
import os
import base64
import copy
from functools import lru_cache
import json
import os
import struct
from functools import lru_cache
from typing import List, Dict, Callable, Any

import aioboto3
import aiohttp
import numpy as np
import ollama

import torch
from openai import (
AsyncOpenAI,
APIConnectionError,
RateLimitError,
Timeout,
AsyncAzureOpenAI,
)

import base64
import struct

from pydantic import BaseModel, Field
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from pydantic import BaseModel, Field
from typing import List, Dict, Callable, Any

from .base import BaseKVStorage
from .utils import (
compute_args_hash,
@@ -66,7 +65,11 @@ async def openai_complete_if_cache(
messages.append({"role": "system", "content": system_prompt})
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})

if hashing_kv is not None:
# Calculate args_hash only when using cache
args_hash = compute_args_hash(model, messages)

# Get embedding cache configuration
embedding_cache_config = hashing_kv.global_config.get(
"embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
@@ -86,7 +89,6 @@ async def openai_complete_if_cache(
return best_cached_response
else:
# Use regular cache
args_hash = compute_args_hash(model, messages)
if_cache_return = await hashing_kv.get_by_id(args_hash)
if if_cache_return is not None:
return if_cache_return["return"]
@@ -159,7 +161,11 @@ async def azure_openai_complete_if_cache(
messages.extend(history_messages)
if prompt is not None:
messages.append({"role": "user", "content": prompt})

if hashing_kv is not None:
# Calculate args_hash only when using cache
args_hash = compute_args_hash(model, messages)

# Get embedding cache configuration
embedding_cache_config = hashing_kv.global_config.get(
"embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
@@ -178,7 +184,7 @@ async def azure_openai_complete_if_cache(
if best_cached_response is not None:
return best_cached_response
else:
args_hash = compute_args_hash(model, messages)
# Use regular cache
if_cache_return = await hashing_kv.get_by_id(args_hash)
if if_cache_return is not None:
return if_cache_return["return"]
@@ -271,6 +277,9 @@ async def bedrock_complete_if_cache(

hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
if hashing_kv is not None:
# Calculate args_hash only when using cache
args_hash = compute_args_hash(model, messages)

# Get embedding cache configuration
embedding_cache_config = hashing_kv.global_config.get(
"embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
@@ -290,7 +299,6 @@ async def bedrock_complete_if_cache(
return best_cached_response
else:
# Use regular cache
args_hash = compute_args_hash(model, messages)
if_cache_return = await hashing_kv.get_by_id(args_hash)
if if_cache_return is not None:
return if_cache_return["return"]
@@ -343,6 +351,11 @@ def initialize_hf_model(model_name):
return hf_model, hf_tokenizer


@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def hf_model_if_cache(
model,
prompt,
@@ -360,6 +373,9 @@ async def hf_model_if_cache(
messages.append({"role": "user", "content": prompt})

if hashing_kv is not None:
# Calculate args_hash only when using cache
args_hash = compute_args_hash(model, messages)

# Get embedding cache configuration
embedding_cache_config = hashing_kv.global_config.get(
"embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
@@ -379,7 +395,6 @@ async def hf_model_if_cache(
return best_cached_response
else:
# Use regular cache
args_hash = compute_args_hash(model, messages)
if_cache_return = await hashing_kv.get_by_id(args_hash)
if if_cache_return is not None:
return if_cache_return["return"]
@@ -448,6 +463,11 @@ async def hf_model_if_cache(
return response_text


@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def ollama_model_if_cache(
model,
prompt,
@@ -468,7 +488,11 @@ async def ollama_model_if_cache(
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})

if hashing_kv is not None:
# Calculate args_hash only when using cache
args_hash = compute_args_hash(model, messages)

# Get embedding cache configuration
embedding_cache_config = hashing_kv.global_config.get(
"embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
@@ -488,7 +512,6 @@ async def ollama_model_if_cache(
return best_cached_response
else:
# Use regular cache
args_hash = compute_args_hash(model, messages)
if_cache_return = await hashing_kv.get_by_id(args_hash)
if if_cache_return is not None:
return if_cache_return["return"]
@@ -542,6 +565,11 @@ def initialize_lmdeploy_pipeline(
return lmdeploy_pipe


@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def lmdeploy_model_if_cache(
model,
prompt,
@@ -620,7 +648,11 @@ async def lmdeploy_model_if_cache(
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})

if hashing_kv is not None:
# Calculate args_hash only when using cache
args_hash = compute_args_hash(model, messages)

# Get embedding cache configuration
embedding_cache_config = hashing_kv.global_config.get(
"embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
@@ -640,7 +672,6 @@ async def lmdeploy_model_if_cache(
return best_cached_response
else:
# Use regular cache
args_hash = compute_args_hash(model, messages)
if_cache_return = await hashing_kv.get_by_id(args_hash)
if if_cache_return is not None:
return if_cache_return["return"]
@@ -831,7 +862,8 @@ async def openai_embedding(
)
async def nvidia_openai_embedding(
texts: list[str],
model: str = "nvidia/llama-3.2-nv-embedqa-1b-v1", # refer to https://build.nvidia.com/nim?filters=usecase%3Ausecase_text_to_embedding
model: str = "nvidia/llama-3.2-nv-embedqa-1b-v1",
# refer to https://build.nvidia.com/nim?filters=usecase%3Ausecase_text_to_embedding
base_url: str = "https://integrate.api.nvidia.com/v1",
api_key: str = None,
input_type: str = "passage", # query for retrieval, passage for embedding
