Skip to content

Commit

Permalink
keep tqdm_async work
Browse files Browse the repository at this point in the history
  • Loading branch information
billvsme committed Dec 13, 2024
1 parent cf0278c commit a788c78
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 20 deletions.
20 changes: 10 additions & 10 deletions lightrag/kg/milvus_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,16 +56,16 @@ async def upsert(self, data: dict[str, dict]):
contents[i : i + self._max_batch_size]
for i in range(0, len(contents), self._max_batch_size)
]
embedding_tasks = [self.embedding_func(batch) for batch in batches]
embeddings_list = []
for f in tqdm_async(
await asyncio.gather(*embedding_tasks),
total=len(embedding_tasks),
desc="Generating embeddings",
unit="batch",
):
embeddings = await f
embeddings_list.append(embeddings)

async def wrapped_task(batch):
result = await self.embedding_func(batch)
pbar.update(1)
return result

embedding_tasks = [wrapped_task(batch) for batch in batches]
pbar = tqdm_async(total=len(embedding_tasks), desc="Generating embeddings", unit="batch")
embeddings_list = await asyncio.gather(*embedding_tasks)

embeddings = np.concatenate(embeddings_list)
for i, d in enumerate(list_data):
d["vector"] = embeddings[i]
Expand Down
20 changes: 10 additions & 10 deletions lightrag/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,16 +96,16 @@ async def upsert(self, data: dict[str, dict]):
contents[i : i + self._max_batch_size]
for i in range(0, len(contents), self._max_batch_size)
]
embedding_tasks = [self.embedding_func(batch) for batch in batches]
embeddings_list = []
for f in tqdm_async(
await asyncio.gather(*embedding_tasks),
total=len(embedding_tasks),
desc="Generating embeddings",
unit="batch",
):
embeddings = await f
embeddings_list.append(embeddings)

async def wrapped_task(batch):
result = await self.embedding_func(batch)
pbar.update(1)
return result

embedding_tasks = [wrapped_task(batch) for batch in batches]
pbar = tqdm_async(total=len(embedding_tasks), desc="Generating embeddings", unit="batch")
embeddings_list = await asyncio.gather(*embedding_tasks)

embeddings = np.concatenate(embeddings_list)
if len(embeddings) == len(list_data):
for i, d in enumerate(list_data):
Expand Down

0 comments on commit a788c78

Please sign in to comment.