Skip to content

Commit

Permalink
Clean up embedding API
Browse files Browse the repository at this point in the history
  • Loading branch information
homanp committed Oct 22, 2023
1 parent 3361e80 commit 529240b
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 16 deletions.
4 changes: 2 additions & 2 deletions nagato/service/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@


def create_vector_embeddings(
    type: str, filter_id: str, url: str = None, content: str = None
) -> List[Union[Document, None]]:
    """Generate chunk nodes for a document source and persist their embeddings.

    Args:
        type: Datasource type understood by ``EmbeddingService``.
            NOTE(review): shadows the ``type`` builtin — kept for
            backward compatibility with keyword callers.
        filter_id: Identifier used to namespace the vectors in the
            vector store (formerly ``finetune_id``).
        url: Optional URL to download the source content from.
        content: Optional raw content, used when ``url`` is not given.

    Returns:
        The list of chunk nodes that were embedded.
    """
    embedding_service = EmbeddingService(type=type, content=content, url=url)
    documents = embedding_service.generate_documents()
    nodes = embedding_service.generate_chunks(documents=documents)
    # Embeddings are written to the vector store as a side effect.
    embedding_service.generate_embeddings(nodes=nodes, filter_id=filter_id)
    return nodes


Expand Down
41 changes: 35 additions & 6 deletions nagato/service/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from llama_index.node_parser import SimpleNodeParser
from numpy import ndarray
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

from nagato.service.vectordb import get_vector_service

Expand All @@ -29,29 +30,57 @@ def generate_documents(self) -> List[Document]:
suffix=self.get_datasource_suffix(), delete=True
) as temp_file:
if self.url:
content = requests.get(self.url).content
response = requests.get(self.url, stream=True)
total_size_in_bytes = int(response.headers.get("content-length", 0))
block_size = 1024
progress_bar = tqdm(
total=total_size_in_bytes,
desc="Downloading file",
unit="iB",
unit_scale=True,
)
content = b""
for data in response.iter_content(block_size):
progress_bar.update(len(data))
content += data
progress_bar.close()
if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
print("ERROR, something went wrong")
else:
content = self.content
temp_file.write(content)
temp_file.flush()
reader = SimpleDirectoryReader(input_files=[temp_file.name])
docs = reader.load_data()

with tqdm(total=3, desc="Processing data") as pbar:
pbar.update()
pbar.set_description("Analyzing data")
reader = SimpleDirectoryReader(input_files=[temp_file.name])
pbar.update()
pbar.set_description("Generating documents")
docs = reader.load_data()
pbar.update()
pbar.set_description("Documents generated")

return docs

def generate_chunks(self, documents: List[Document]) -> List[Union[Document, None]]:
    """Split *documents* into chunks of 350 tokens with a 20-token overlap.

    A two-step tqdm progress bar reports the parsing phase on top of the
    parser's own ``show_progress`` output.

    Args:
        documents: Documents produced by ``generate_documents``.

    Returns:
        The chunk nodes parsed from *documents*.
    """
    parser = SimpleNodeParser.from_defaults(chunk_size=350, chunk_overlap=20)
    with tqdm(total=2, desc="Generating chunks") as pbar:
        pbar.update()
        pbar.set_description("Generating nodes")
        nodes = parser.get_nodes_from_documents(documents, show_progress=True)
        pbar.update()
    return nodes

def generate_embeddings(
self,
nodes: List[Union[Document, None]],
finetune_id: str,
filter_id: str,
) -> List[ndarray]:
vectordb = get_vector_service(
provider="pinecone",
index_name="all-minilm-l6-v2",
namespace=finetune_id,
filter_id=filter_id,
dimension=384,
)
model = SentenceTransformer(
Expand Down
16 changes: 8 additions & 8 deletions nagato/service/vectordb.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@


class VectorDBService(ABC):
def __init__(self, index_name: str, dimension: int, namespace: str = None):
def __init__(self, index_name: str, dimension: int, filter_id: str = None):
self.index_name = index_name
self.namespace = namespace
self.filter_id = filter_id
self.dimension = dimension

@abstractmethod
Expand All @@ -22,9 +22,9 @@ def query():


class PineconeVectorService(VectorDBService):
def __init__(self, index_name: str, dimension: int, namespace: str = None):
def __init__(self, index_name: str, dimension: int, filter_id: str = None):
super().__init__(
index_name=index_name, dimension=dimension, namespace=namespace
index_name=index_name, dimension=dimension, filter_id=filter_id
)
pinecone.init(api_key=config("PINECONE_API_KEY"))
# Create a new vector index if it doesn't
Expand All @@ -36,19 +36,19 @@ def __init__(self, index_name: str, dimension: int, namespace: str = None):
self.index = pinecone.Index(index_name=self.index_name)

def upsert(self, vectors: ndarray):
    """Insert or update *vectors* in the Pinecone index.

    Vectors are written into the Pinecone namespace identified by
    ``self.filter_id`` (formerly ``self.namespace``) so later queries
    can be scoped to the same filter.
    """
    self.index.upsert(vectors=vectors, namespace=self.filter_id)

def query(self, queries: List[ndarray], top_k: int, include_metadata: bool = True):
    """Query the Pinecone index within this service's namespace.

    Args:
        queries: Query vectors to search with.
        top_k: Number of nearest matches to return per query.
        include_metadata: Whether matches include their stored metadata.

    Returns:
        The raw Pinecone query response, scoped to ``self.filter_id``.
    """
    return self.index.query(
        queries=queries,
        top_k=top_k,
        include_metadata=include_metadata,
        namespace=self.filter_id,
    )


def get_vector_service(
provider: str, index_name: str, namespace: str = None, dimension: int = 384
provider: str, index_name: str, filter_id: str = None, dimension: int = 384
):
services = {
"pinecone": PineconeVectorService,
Expand All @@ -58,4 +58,4 @@ def get_vector_service(
service = services.get(provider)
if service is None:
raise ValueError(f"Unsupported provider: {provider}")
return service(index_name=index_name, namespace=namespace, dimension=dimension)
return service(index_name=index_name, filter_id=filter_id, dimension=dimension)

0 comments on commit 529240b

Please sign in to comment.