Skip to content

Commit

Permalink
Feat/summarization pipeline (#26)
Browse files Browse the repository at this point in the history
* WIP

* Add summarization pipeline

* Fix formatting

* Update .env.example
  • Loading branch information
homanp authored Feb 10, 2024
1 parent b1eabc5 commit 1fd9bb4
Show file tree
Hide file tree
Showing 7 changed files with 87 additions and 19 deletions.
3 changes: 1 addition & 2 deletions .env.example
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
API_BASE_URL=https://rag.superagent.sh
COHERE_API_KEY=
HUGGINGFACE_API_KEY=
JWT_SECRET=
OPENAI_API_KEY=
29 changes: 22 additions & 7 deletions api/ingest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
from typing import Dict

import requests
import aiohttp
from fastapi import APIRouter

from models.ingest import RequestPayload
Expand All @@ -17,12 +18,26 @@ async def ingest(payload: RequestPayload) -> Dict:
vector_credentials=payload.vector_database,
)
documents = await embedding_service.generate_documents()
chunks = await embedding_service.generate_chunks(documents=documents)
await embedding_service.generate_embeddings(nodes=chunks)
summary_documents = await embedding_service.generate_summary_documents(
documents=documents
)
chunks, summary_chunks = await asyncio.gather(
embedding_service.generate_chunks(documents=documents),
embedding_service.generate_chunks(documents=summary_documents),
)

await asyncio.gather(
embedding_service.generate_embeddings(nodes=chunks),
embedding_service.generate_embeddings(
nodes=summary_chunks, index_name=f"{payload.index_name}summary"
),
)

if payload.webhook_url:
requests.post(
url=payload.webhook_url,
json={"index_name": payload.index_name, "status": "completed"},
)
async with aiohttp.ClientSession() as session:
await session.post(
url=payload.webhook_url,
json={"index_name": payload.index_name, "status": "completed"},
)

return {"success": True, "index_name": payload.index_name}
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ colorlog==6.8.2
cryptography==41.0.7
dataclasses-json==0.6.3
Deprecated==1.2.14
dirtyjson==1.0.8
distro==1.9.0
dnspython==2.4.2
docx2txt==0.8
Expand Down Expand Up @@ -48,7 +49,7 @@ idna==3.6
importlib-metadata==6.11.0
Jinja2==3.1.3
joblib==1.3.2
llama-index==0.9.30
llama-index==0.9.46
loguru==0.7.2
lxml==5.1.0
MarkupSafe==2.1.3
Expand Down
29 changes: 22 additions & 7 deletions service/embedding.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import copy
from tempfile import NamedTemporaryFile
from typing import Any, List, Union
from typing import Any, List, Optional, Union

import numpy as np
import requests
Expand All @@ -11,6 +12,7 @@

from models.file import File
from service.vector_database import get_vector_service
from utils.summarise import completion


class EmbeddingService:
Expand All @@ -37,9 +39,9 @@ async def generate_documents(self) -> List[Document]:
for file in tqdm(self.files, desc="Generating documents"):
suffix = self._get_datasource_suffix(file.type.value)
with NamedTemporaryFile(suffix=suffix, delete=True) as temp_file:
response = requests.get(url=file.url)
temp_file.write(response.content)
temp_file.flush()
with requests.get(url=file.url) as response: # Add context manager here
temp_file.write(response.content)
temp_file.flush()
reader = SimpleDirectoryReader(input_files=[temp_file.name])
docs = reader.load_data()
for doc in docs:
Expand All @@ -55,8 +57,7 @@ async def generate_chunks(
return nodes

async def generate_embeddings(
self,
nodes: List[Union[Document, None]],
self, nodes: List[Union[Document, None]], index_name: Optional[str] = None
) -> List[tuple[str, list, dict[str, Any]]]:
pbar = tqdm(total=len(nodes), desc="Generating embeddings")

Expand All @@ -81,8 +82,22 @@ async def generate_embedding(node):
embeddings = await asyncio.gather(*tasks)
pbar.close()
vector_service = get_vector_service(
index_name=self.index_name, credentials=self.vector_credentials
index_name=index_name or self.index_name,
credentials=self.vector_credentials,
)
await vector_service.upsert(embeddings=[e for e in embeddings if e is not None])

return [e for e in embeddings if e is not None]

async def generate_summary_documents(
    self, documents: List[Document]
) -> List[Document]:
    """Produce a summarized copy of each input document.

    Each document is deep-copied so the caller's originals are never
    mutated, then the copy's ``text`` is replaced with an LLM-generated
    summary via ``utils.summarise.completion``.

    Args:
        documents: Source documents to summarize.

    Returns:
        New ``Document`` objects, in input order, whose text is the summary.
    """
    summary_documents = []
    # Context manager guarantees the progress bar is closed even if a
    # summarization call raises (the original leaked the bar on error).
    with tqdm(total=len(documents), desc="Summarizing documents") as pbar:
        for document in documents:
            # Copy first: the summary prompt still sees the full original
            # text, and the input document itself is left untouched.
            doc_copy = copy.deepcopy(document)
            doc_copy.text = await completion(document=doc_copy)
            summary_documents.append(doc_copy)
            pbar.update()
    # NOTE(review): summaries run sequentially; asyncio.gather could
    # parallelize this if API rate limits allow — confirm before changing.
    return summary_documents
9 changes: 7 additions & 2 deletions service/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def create_route_layer() -> RouteLayer:
score_threshold=0.5,
)
]
print(config("COHERE_API_KEY"))
encoder = CohereEncoder(cohere_api_key=config("COHERE_API_KEY"))
return RouteLayer(encoder=encoder, routes=routes)

Expand All @@ -40,9 +41,13 @@ async def get_documents(vector_service: VectorService, payload: RequestPayload)
async def query(payload: RequestPayload) -> List:
rl = create_route_layer()
decision = rl(payload.input).name
print(decision)

if decision == "summarize":
return []
vector_service: VectorService = get_vector_service(
index_name=f"{payload.index_name}summary",
credentials=payload.vector_database,
)
return await get_documents(vector_service, payload)

vector_service: VectorService = get_vector_service(
index_name=payload.index_name, credentials=payload.vector_database
Expand Down
Empty file added utils/__init__.py
Empty file.
33 changes: 33 additions & 0 deletions utils/summarise.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from decouple import config
from llama_index import Document
from openai import AsyncOpenAI

client = AsyncOpenAI(
api_key=config("OPENAI_API_KEY"),
)


def _generate_content(document: Document) -> str:
return f"""Make an in depth summary the block of text below.
Text:
------------------------------------------
{document.get_content()}
------------------------------------------
Your summary:"""


async def completion(document: "Document"):
    """Summarize *document* with the chat model and return the summary text.

    Args:
        document: Document whose content is summarized.

    Returns:
        The model's summary — ``message.content`` of the first choice.
        May be ``None`` if the API returns a choice without content.

    Errors raised by the OpenAI client propagate to the caller.
    """
    content = _generate_content(document)
    # Named `response` rather than `completion`: the original shadowed
    # the function's own name with this local variable.
    response = await client.chat.completions.create(
        messages=[
            {
                "role": "user",
                "content": content,
            }
        ],
        model="gpt-3.5-turbo-16k",
    )

    return response.choices[0].message.content

0 comments on commit 1fd9bb4

Please sign in to comment.