Skip to content

Commit

Permalink
feat: Added encoders to query
Browse files Browse the repository at this point in the history
  • Loading branch information
simjak committed Feb 11, 2024
1 parent 2146449 commit e48afc3
Show file tree
Hide file tree
Showing 9 changed files with 153 additions and 43 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# SuperRag

Super-performant RAG pipeline for AI Agents/Assistants.
Super-performant RAG pipeline for AI Agents/Assistants.

## API

Expand All @@ -23,6 +23,7 @@ Input example:
}
},
"index_name": "my_index",
"encoder": "my_encoder"
"webhook_url": "https://my-webhook-url"
}
```
Expand All @@ -41,6 +42,7 @@ Input example:
}
},
"index_name": "my_index",
"encoder": "my_encoder",
}
```

Expand Down
17 changes: 3 additions & 14 deletions api/ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@
import aiohttp
from fastapi import APIRouter

import encoders
from models.ingest import EncoderEnum, RequestPayload
from service.embedding import EmbeddingService
from models.ingest import RequestPayload
from service.embedding import EmbeddingService, get_encoder

router = APIRouter()

Expand All @@ -21,17 +20,7 @@ async def ingest(payload: RequestPayload) -> Dict:
documents = await embedding_service.generate_documents()
chunks = await embedding_service.generate_chunks(documents=documents)

encoder_mapping = {
EncoderEnum.cohere: encoders.CohereEncoder,
EncoderEnum.openai: encoders.OpenAIEncoder,
EncoderEnum.huggingface: encoders.HuggingFaceEncoder,
EncoderEnum.fastembed: encoders.FastEmbedEncoder,
}

encoder_class = encoder_mapping.get(payload.encoder)
if encoder_class is None:
raise ValueError(f"Unsupported encoder: {payload.encoder}")
encoder = encoder_class()
encoder = get_encoder(encoder_type=payload.encoder)

summary_documents = await embedding_service.generate_summary_documents(
documents=documents
Expand Down
60 changes: 60 additions & 0 deletions dev/walkthrough.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,66 @@
"\n",
"print(response.json())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Query the index\n",
"query_url = f\"{API_URL}/api/v1/query\"\n",
"\n",
"query_payload = {\n",
" \"input\": \"What is the best chunk strategy?\",\n",
" \"vector_database\": {\n",
" \"type\": \"pinecone\",\n",
" \"config\": {\n",
" \"api_key\": PINECONE_API_KEY,\n",
" \"host\": PINECONE_HOST,\n",
" }\n",
" },\n",
" \"index_name\": PINECONE_INDEX,\n",
" \"encoder\": \"openai\",\n",
"}\n",
"\n",
"query_response = requests.post(query_url, json=query_payload)\n",
"\n",
"print(query_response.json())\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Delete the index\n",
"query_url = f\"{API_URL}/api/v1/delete\"\n",
"\n",
"delete_payload = {\n",
" \"file_url\": \"https://arxiv.org/pdf/2402.05131.pdf\",\n",
" \"vector_database\": {\n",
" \"type\": \"pinecone\",\n",
" \"config\": {\n",
" \"api_key\": PINECONE_API_KEY,\n",
" \"host\": PINECONE_HOST,\n",
" }\n",
" },\n",
" \"index_name\": PINECONE_INDEX,\n",
"}\n",
"\n",
"delete_response = requests.delete(query_url, json=delete_payload)\n",
"\n",
"print(delete_response.json())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
1 change: 1 addition & 0 deletions encoders/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ class BaseEncoder(BaseModel):
name: str
score_threshold: float
type: str = Field(default="base")
dimension: int = Field(default=1536)

class Config:
arbitrary_types_allowed = True
Expand Down
8 changes: 6 additions & 2 deletions encoders/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,20 @@
from typing import List, Optional

import openai
from dotenv import load_dotenv
from openai import OpenAIError
from openai.types import CreateEmbeddingResponse
from semantic_router.utils.logger import logger

from encoders import BaseEncoder
from semantic_router.utils.logger import logger

load_dotenv()


class OpenAIEncoder(BaseEncoder):
client: Optional[openai.Client]
type: str = "openai"
dimension: int = 1536

def __init__(
self,
Expand All @@ -21,7 +25,7 @@ def __init__(
score_threshold: float = 0.82,
):
if name is None:
name = os.getenv("OPENAI_MODEL_NAME", "text-embedding-ada-002")
name = os.getenv("OPENAI_MODEL_NAME", "text-embedding-3-small")
super().__init__(name=name, score_threshold=score_threshold)
api_key = openai_api_key or os.getenv("OPENAI_API_KEY")
if api_key is None:
Expand Down
2 changes: 2 additions & 0 deletions models/query.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import List, Optional

from pydantic import BaseModel
from models.ingest import EncoderEnum

from models.vector_database import VectorDatabase

Expand All @@ -9,6 +10,7 @@ class RequestPayload(BaseModel):
input: str
vector_database: VectorDatabase
index_name: str
encoder: EncoderEnum = EncoderEnum.openai


class ResponseData(BaseModel):
Expand Down
17 changes: 17 additions & 0 deletions service/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
from tqdm import tqdm

from encoders import BaseEncoder
import encoders
from models.file import File
from models.ingest import EncoderEnum
from service.vector_database import get_vector_service
from utils.summarise import completion

Expand Down Expand Up @@ -85,6 +87,7 @@ async def generate_embedding(node):
vector_service = get_vector_service(
index_name=index_name or self.index_name,
credentials=self.vector_credentials,
encoder=encoder,
)
await vector_service.upsert(embeddings=[e for e in embeddings if e is not None])

Expand All @@ -102,3 +105,17 @@ async def generate_summary_documents(
pbar.update()
pbar.close()
return summary_documents


def get_encoder(*, encoder_type: EncoderEnum) -> encoders.BaseEncoder:
    """Instantiate the embedding encoder registered for *encoder_type*.

    Args:
        encoder_type: Which provider-backed encoder to build.

    Returns:
        A fresh instance of the matching ``BaseEncoder`` subclass.

    Raises:
        ValueError: If no encoder class is registered for *encoder_type*.
    """
    registry = {
        EncoderEnum.cohere: encoders.CohereEncoder,
        EncoderEnum.openai: encoders.OpenAIEncoder,
        EncoderEnum.huggingface: encoders.HuggingFaceEncoder,
        EncoderEnum.fastembed: encoders.FastEmbedEncoder,
    }
    try:
        encoder_class = registry[encoder_type]
    except KeyError:
        raise ValueError(f"Unsupported encoder: {encoder_type}") from None
    return encoder_class()
15 changes: 11 additions & 4 deletions service/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from semantic_router.route import Route

from models.query import RequestPayload
from service.embedding import get_encoder
from service.vector_database import VectorService, get_vector_service


Expand All @@ -27,7 +28,9 @@ def create_route_layer() -> RouteLayer:
return RouteLayer(encoder=encoder, routes=routes)


async def get_documents(vector_service: VectorService, payload: RequestPayload) -> List:
async def get_documents(
*, vector_service: VectorService, payload: RequestPayload
) -> List:
chunks = await vector_service.query(input=payload.input, top_k=4)
documents = await vector_service.convert_to_rerank_format(chunks=chunks)

Expand All @@ -41,15 +44,19 @@ async def get_documents(vector_service: VectorService, payload: RequestPayload)
async def query(payload: RequestPayload) -> List:
    """Answer a query by routing it, then retrieving matching documents.

    A semantic route layer first classifies the input; a "summarize"
    decision targets the side-car summary index (``<index_name>summary``),
    otherwise the main index is queried. In both branches the encoder
    selected in the payload is passed to the vector service so query
    embeddings match the ones used at ingest time.

    Args:
        payload: Query input, vector-database credentials, index name and
            encoder choice.

    Returns:
        A list of (reranked) documents relevant to ``payload.input``.
    """
    # NOTE(review): the scraped diff interleaved old and new lines here,
    # producing duplicate return statements; this is the post-commit form.
    rl = create_route_layer()
    decision = rl(payload.input).name
    encoder = get_encoder(encoder_type=payload.encoder)

    if decision == "summarize":
        vector_service: VectorService = get_vector_service(
            index_name=f"{payload.index_name}summary",
            credentials=payload.vector_database,
            encoder=encoder,
        )
        return await get_documents(vector_service=vector_service, payload=payload)

    vector_service = get_vector_service(
        index_name=payload.index_name,
        credentials=payload.vector_database,
        encoder=encoder,
    )
    return await get_documents(vector_service=vector_service, payload=payload)
Loading

0 comments on commit e48afc3

Please sign in to comment.