Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/add concurrent embedding limit #441

Merged
merged 2 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions lightrag/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,16 @@ async def upsert(self, data: dict[str, dict]):
embeddings = await f
embeddings_list.append(embeddings)
embeddings = np.concatenate(embeddings_list)
for i, d in enumerate(list_data):
d["__vector__"] = embeddings[i]
results = self._client.upsert(datas=list_data)
return results
if len(embeddings) == len(list_data):
for i, d in enumerate(list_data):
d["__vector__"] = embeddings[i]
results = self._client.upsert(datas=list_data)
return results
else:
# sometimes the embedding is not returned correctly. just log it.
logger.error(
f"embedding is not 1-1 with data, {len(embeddings)} != {len(list_data)}"
)

async def query(self, query: str, top_k=5):
embedding = await self.embedding_func([query])
Expand Down
21 changes: 20 additions & 1 deletion lightrag/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,17 @@

from lightrag.prompt import PROMPTS


class UnlimitedSemaphore:
    """No-op async context manager that allows unlimited access.

    Entering and leaving the context does nothing, so any number of
    coroutines may be inside ``async with`` at the same time.
    """

    async def __aenter__(self):
        """Enter immediately; nothing is acquired and ``None`` is returned."""
        return None

    async def __aexit__(self, exc_type, exc, tb):
        """Leave the context without releasing anything.

        Returns ``None`` (falsy), so an exception raised inside the
        ``async with`` body is not suppressed.
        """
        return None


ENCODER = None

logger = logging.getLogger("lightrag")
Expand All @@ -42,9 +53,17 @@ class EmbeddingFunc:
embedding_dim: int
max_token_size: int
func: callable
concurrent_limit: int = 16

def __post_init__(self):
    """Build the gate stored on ``self._semaphore``.

    A ``concurrent_limit`` of 0 selects ``UnlimitedSemaphore``; any other
    value is passed to ``asyncio.Semaphore`` as its initial counter.
    """
    limit = self.concurrent_limit
    self._semaphore = (
        UnlimitedSemaphore() if limit == 0 else asyncio.Semaphore(limit)
    )

async def __call__(self, *args, **kwargs) -> np.ndarray:
    """Await the wrapped ``func`` while holding ``self._semaphore``.

    All positional and keyword arguments are forwarded unchanged, and the
    awaited result of ``self.func`` is returned.
    """
    # The call has to happen inside the ``async with``; returning before it
    # would bypass the gate, and the limit would never take effect.
    async with self._semaphore:
        return await self.func(*args, **kwargs)


def locate_json_string_body_from_string(content: str) -> Union[str, None]:
Expand Down
Loading