Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generate embeddings for unstructured data #8

Merged
merged 3 commits into from
Oct 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .env.example
Original file line number Diff line number Diff line change
@@ -1,2 +1,5 @@
DATABASE_URL=
DATABASE_MIGRATION_URL=
DATABASE_MIGRATION_URL=
OPENAI_API_KEY=
HF_API_KEY=
PINECONE_API_KEY=
11 changes: 10 additions & 1 deletion lib/api/ingest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from fastapi import APIRouter

from lib.service.embedding import EmbeddingService
from lib.utils.prisma import prisma

router = APIRouter()


Expand All @@ -8,6 +11,12 @@
name="ingest",
description="Ingest data",
)
async def ingest():
async def ingest(body: dict):
    """Endpoint for ingesting data.

    Creates a datasource record from the request body, then loads, chunks and
    embeds the datasource.

    Args:
        body: Fields of the datasource to create; passed as-is to
            ``prisma.datasource.create``.

    Returns:
        ``{"success": True, "data": None}`` once the pipeline has finished.
    """
    # Function-scope imports keep this block self-contained.
    import logging

    from fastapi.concurrency import run_in_threadpool

    datasource = await prisma.datasource.create(data={**body})
    embedding_service = EmbeddingService(datasource=datasource)

    def _embed_datasource():
        # Every EmbeddingService step is a plain synchronous call; running
        # them directly in this coroutine would stall the event loop (and
        # every other request) until they finish.
        documents = embedding_service.generate_documents()
        nodes = embedding_service.generate_chunks(documents=documents)
        return embedding_service.generate_embeddings(nodes=nodes)

    embeddings = await run_in_threadpool(_embed_datasource)
    # Log a summary instead of printing every raw vector to stdout.
    logging.getLogger(__name__).info(
        "Generated %d embeddings for datasource %s",
        len(embeddings),
        datasource.id,
    )
    return {"success": True, "data": None}
65 changes: 65 additions & 0 deletions lib/service/embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from tempfile import NamedTemporaryFile
from typing import List, Union

import requests
from decouple import config
from llama_index import Document, SimpleDirectoryReader
from llama_index.node_parser import SimpleNodeParser
from numpy import ndarray
from sentence_transformers import SentenceTransformer

from lib.service.vectordb import get_vector_service
from prisma.models import Datasource


class EmbeddingService:
    """Turn a datasource into text chunks and store their embeddings.

    Intended call order: ``generate_documents`` -> ``generate_chunks`` ->
    ``generate_embeddings``.
    """

    # Datasource type -> suffix of the temporary file handed to the reader.
    _SUFFIXES = {"TXT": ".txt", "PDF": ".pdf", "MARKDOWN": ".md"}
    # Seconds to wait on a remote datasource before giving up.
    _DOWNLOAD_TIMEOUT = 30

    def __init__(self, datasource: "Datasource"):
        self.datasource = datasource

    def get_datasource_suffix(self) -> str:
        """Return the file suffix matching ``self.datasource.type``.

        Raises:
            ValueError: If the datasource type is not TXT, PDF or MARKDOWN.
        """
        try:
            return self._SUFFIXES[self.datasource.type]
        except KeyError:
            raise ValueError("Unsupported datasource type") from None

    def _load_content(self) -> bytes:
        """Return the raw bytes of the datasource.

        Downloads ``self.datasource.url`` when it is set, otherwise falls back
        to ``self.datasource.content``.

        Raises:
            ValueError: If the datasource has neither a url nor content.
            requests.HTTPError: If the download answers with an error status.
        """
        if self.datasource.url:
            response = requests.get(
                self.datasource.url, timeout=self._DOWNLOAD_TIMEOUT
            )
            # Fail loudly rather than embedding an HTTP error page.
            response.raise_for_status()
            return response.content
        content = self.datasource.content
        if content is None:
            raise ValueError("Datasource has neither a url nor content")
        if isinstance(content, str):
            # The temporary file is opened in binary mode, so text must be
            # encoded before it can be written.
            return content.encode("utf-8")
        return content

    def generate_documents(self) -> List["Document"]:
        """Load the datasource into llama_index documents.

        The content is written to a temporary file whose suffix reflects the
        datasource type, and that file is read with ``SimpleDirectoryReader``.

        Raises:
            ValueError: If the datasource type is unsupported or the
                datasource has no content.
        """
        suffix = self.get_datasource_suffix()
        content = self._load_content()
        with NamedTemporaryFile(suffix=suffix, delete=True) as temp_file:
            temp_file.write(content)
            # Flush so the reader, which opens the file by name, sees the data.
            temp_file.flush()
            reader = SimpleDirectoryReader(input_files=[temp_file.name])
            return reader.load_data()

    def generate_chunks(
        self, documents: List["Document"]
    ) -> List[Union["Document", None]]:
        """Split ``documents`` into nodes (chunk size 350, overlap 20)."""
        parser = SimpleNodeParser.from_defaults(chunk_size=350, chunk_overlap=20)
        return parser.get_nodes_from_documents(documents, show_progress=True)

    def generate_embeddings(
        self, nodes: List[Union["Document", None]]
    ) -> List[tuple]:
        """Embed ``nodes`` and upsert the vectors into the vector database.

        ``None`` entries in ``nodes`` are skipped. Vectors are stored in the
        "all-minilm-l6-v2" index under a namespace equal to the datasource id.

        Returns:
            One ``(node id, embedding as a list of floats, node metadata)``
            tuple per non-``None`` node, or an empty list when there is
            nothing to embed.
        """
        valid_nodes = [node for node in nodes if node is not None]
        if not valid_nodes:
            # Nothing to embed: skip the model load and the vector DB round
            # trip instead of sending an empty upsert.
            return []
        vectordb = get_vector_service(
            provider="pinecone",
            index_name="all-minilm-l6-v2",
            namespace=self.datasource.id,
            dimension=384,
        )
        model = SentenceTransformer(
            "all-MiniLM-L6-v2", use_auth_token=config("HF_API_KEY")
        )
        # Encode all chunks in one batched call rather than one call per node.
        vectors = model.encode([node.text for node in valid_nodes])
        embeddings = [
            (node.id_, vector.tolist(), node.metadata)
            for node, vector in zip(valid_nodes, vectors)
        ]
        vectordb.upsert(vectors=embeddings)
        return embeddings
59 changes: 59 additions & 0 deletions lib/service/vectordb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from abc import ABC, abstractmethod
from typing import List

import pinecone
from decouple import config
from numpy import ndarray


class VectorDBService(ABC):
    """Abstract interface that every vector-database backend implements.

    Args:
        index_name: Name of the vector index to operate on.
        dimension: Dimensionality of the vectors stored in the index.
        namespace: Optional namespace inside the index.
    """

    def __init__(self, index_name: str, dimension: int, namespace: str = None):
        self.index_name = index_name
        self.namespace = namespace
        self.dimension = dimension

    @abstractmethod
    def upsert(self, vectors: list):
        """Insert or update ``vectors`` in the index."""

    @abstractmethod
    def query(self, queries: list, top_k: int):
        """Return the ``top_k`` closest matches for each entry in ``queries``."""


class PineconeVectorService(VectorDBService):
    """Pinecone-backed implementation of ``VectorDBService``.

    Constructing an instance initialises the Pinecone client with
    ``PINECONE_API_KEY`` and creates the index when it does not exist yet.

    Args:
        index_name: Name of the Pinecone index to use (created on demand).
        dimension: Vector dimensionality used when the index is created.
        namespace: Optional namespace that upserts and queries are scoped to.
    """

    def __init__(self, index_name: str, dimension: int, namespace: str = None):
        super().__init__(
            index_name=index_name, dimension=dimension, namespace=namespace
        )
        pinecone.init(api_key=config("PINECONE_API_KEY"))
        # Create the index if it doesn't exist yet; its dimension comes from
        # the constructor arguments.
        if index_name not in pinecone.list_indexes():
            pinecone.create_index(
                name=index_name, metric="cosine", shards=1, dimension=dimension
            )
        self.index = pinecone.Index(index_name=self.index_name)

    def __del__(self):
        # Looked up defensively: an exception raised inside __del__ is only
        # reported as an "Exception ignored" warning, e.g. when the module
        # global is already cleared at interpreter shutdown or the client
        # does not expose deinit().
        # NOTE(review): deinit() acts on the module-level client, not on this
        # instance - confirm that is intended when several services coexist.
        deinit = getattr(pinecone, "deinit", None)
        if callable(deinit):
            deinit()

    def upsert(self, vectors: List[tuple]):
        """Insert or update ``vectors`` in this service's namespace.

        Args:
            vectors: Items to store; the embedding service passes
                ``(id, values, metadata)`` tuples.
        """
        self.index.upsert(vectors=vectors, namespace=self.namespace)

    def query(self, queries: List[ndarray], top_k: int):
        """Return the ``top_k`` nearest matches for each query vector.

        The search is scoped to the same namespace ``upsert`` writes to;
        otherwise vectors stored under a namespace would never be found.
        """
        return self.index.query(
            queries=queries, top_k=top_k, namespace=self.namespace
        )


def get_vector_service(
    provider: str, index_name: str, namespace: str = None, dimension: int = 384
):
    """Build the vector-database service registered under ``provider``.

    Args:
        provider: Backend key; only "pinecone" is registered.
        index_name: Name of the index the service operates on.
        namespace: Optional namespace forwarded to the service.
        dimension: Vector dimensionality forwarded to the service.

    Returns:
        A new instance of the service class registered for ``provider``.

    Raises:
        ValueError: If ``provider`` is not a registered backend.
    """
    registry = {
        "pinecone": PineconeVectorService,
        # Add other providers here
        # "weaviate": WeaviateVectorService,
    }
    if provider not in registry:
        raise ValueError(f"Unsupported provider: {provider}")
    backend_cls = registry[provider]
    return backend_cls(
        index_name=index_name, namespace=namespace, dimension=dimension
    )
3 changes: 3 additions & 0 deletions lib/utils/prisma.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from prisma import Prisma

# Module-level Prisma client instance; lib/api/ingest.py imports it to create
# datasource records.
prisma = Prisma()
Loading