From 529240b4faf11fe1925918f6700925f4f5437020 Mon Sep 17 00:00:00 2001 From: Ismail Pelaseyed Date: Sun, 22 Oct 2023 10:59:46 -0700 Subject: [PATCH] Clean up embedding API --- nagato/service/__init__.py | 4 ++-- nagato/service/embedding.py | 41 +++++++++++++++++++++++++++++++------ nagato/service/vectordb.py | 16 +++++++-------- 3 files changed, 45 insertions(+), 16 deletions(-) diff --git a/nagato/service/__init__.py b/nagato/service/__init__.py index 61db605..1f43a6a 100644 --- a/nagato/service/__init__.py +++ b/nagato/service/__init__.py @@ -8,12 +8,12 @@ def create_vector_embeddings( - type: str, finetune_id: str, url: str = None, content: str = None + type: str, filter_id: str, url: str = None, content: str = None ) -> List[Union[Document, None]]: embedding_service = EmbeddingService(type=type, content=content, url=url) documents = embedding_service.generate_documents() nodes = embedding_service.generate_chunks(documents=documents) - embedding_service.generate_embeddings(nodes=nodes, finetune_id=finetune_id) + embedding_service.generate_embeddings(nodes=nodes, filter_id=filter_id) return nodes diff --git a/nagato/service/embedding.py b/nagato/service/embedding.py index 977c87f..665600c 100644 --- a/nagato/service/embedding.py +++ b/nagato/service/embedding.py @@ -7,6 +7,7 @@ from llama_index.node_parser import SimpleNodeParser from numpy import ndarray from sentence_transformers import SentenceTransformer +from tqdm import tqdm from nagato.service.vectordb import get_vector_service @@ -29,29 +30,57 @@ def generate_documents(self) -> List[Document]: suffix=self.get_datasource_suffix(), delete=True ) as temp_file: if self.url: - content = requests.get(self.url).content + response = requests.get(self.url, stream=True) + total_size_in_bytes = int(response.headers.get("content-length", 0)) + block_size = 1024 + progress_bar = tqdm( + total=total_size_in_bytes, + desc="Downloading file", + unit="iB", + unit_scale=True, + ) + content = b"" + for data in response.iter_content(block_size): + progress_bar.update(len(data)) + content += data + progress_bar.close() + if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes: + print("ERROR, something went wrong") else: content = self.content temp_file.write(content) temp_file.flush() - reader = SimpleDirectoryReader(input_files=[temp_file.name]) - docs = reader.load_data() + + with tqdm(total=3, desc="Processing data") as pbar: + pbar.update() + pbar.set_description("Analyzing data") + reader = SimpleDirectoryReader(input_files=[temp_file.name]) + pbar.update() + pbar.set_description("Generating documents") + docs = reader.load_data() + pbar.update() + pbar.set_description("Documents generated") + return docs def generate_chunks(self, documents: List[Document]) -> List[Union[Document, None]]: parser = SimpleNodeParser.from_defaults(chunk_size=350, chunk_overlap=20) - nodes = parser.get_nodes_from_documents(documents, show_progress=True) + with tqdm(total=2, desc="Generating chunks") as pbar: + pbar.update() + pbar.set_description("Generating nodes") + nodes = parser.get_nodes_from_documents(documents, show_progress=True) + pbar.update() return nodes def generate_embeddings( self, nodes: List[Union[Document, None]], - finetune_id: str, + filter_id: str, ) -> List[ndarray]: vectordb = get_vector_service( provider="pinecone", index_name="all-minilm-l6-v2", - namespace=finetune_id, + filter_id=filter_id, dimension=384, ) model = SentenceTransformer( diff --git a/nagato/service/vectordb.py b/nagato/service/vectordb.py index 078c3b0..eb34fae 100644 --- a/nagato/service/vectordb.py +++ b/nagato/service/vectordb.py @@ -7,9 +7,9 @@ class VectorDBService(ABC): - def __init__(self, index_name: str, dimension: int, namespace: str = None): + def __init__(self, index_name: str, dimension: int, filter_id: str = None): self.index_name = index_name - self.namespace = namespace + self.filter_id = filter_id self.dimension = dimension @abstractmethod @@ -22,9 +22,9 @@ def query(): class PineconeVectorService(VectorDBService): - def __init__(self, index_name: str, dimension: int, namespace: str = None): + def __init__(self, index_name: str, dimension: int, filter_id: str = None): super().__init__( - index_name=index_name, dimension=dimension, namespace=namespace + index_name=index_name, dimension=dimension, filter_id=filter_id ) pinecone.init(api_key=config("PINECONE_API_KEY")) # Create a new vector index if it doesn't @@ -36,19 +36,19 @@ def __init__(self, index_name: str, dimension: int, namespace: str = None): self.index = pinecone.Index(index_name=self.index_name) def upsert(self, vectors: ndarray): - self.index.upsert(vectors=vectors, namespace=self.namespace) + self.index.upsert(vectors=vectors, namespace=self.filter_id) def query(self, queries: List[ndarray], top_k: int, include_metadata: bool = True): return self.index.query( queries=queries, top_k=top_k, include_metadata=include_metadata, - namespace=self.namespace, + namespace=self.filter_id, ) def get_vector_service( - provider: str, index_name: str, namespace: str = None, dimension: int = 384 + provider: str, index_name: str, filter_id: str = None, dimension: int = 384 ): services = { "pinecone": PineconeVectorService, @@ -58,4 +58,4 @@ def get_vector_service( service = services.get(provider) if service is None: raise ValueError(f"Unsupported provider: {provider}") - return service(index_name=index_name, namespace=namespace, dimension=dimension) + return service(index_name=index_name, filter_id=filter_id, dimension=dimension)