Skip to content

Commit

Permalink
Use fastembed (#22)
Browse files Browse the repository at this point in the history
* Use fastembed

* Minor tweaks

* Fix formatting
  • Loading branch information
homanp authored Feb 1, 2024
1 parent d035e58 commit 26c38cf
Show file tree
Hide file tree
Showing 14 changed files with 52 additions and 63 deletions.
5 changes: 3 additions & 2 deletions api/delete.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from fastapi import APIRouter, Depends
from models.delete import RequestPayload, ResponsePayload
from service.vector_database import get_vector_service, VectorService

from auth.user import get_current_api_user
from models.delete import RequestPayload, ResponsePayload
from service.vector_database import VectorService, get_vector_service

router = APIRouter()

Expand Down
8 changes: 4 additions & 4 deletions api/ingest.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import requests

from typing import Dict

import requests
from fastapi import APIRouter, Depends

from auth.user import get_current_api_user
from models.ingest import RequestPayload
from service.embedding import EmbeddingService
from auth.user import get_current_api_user


router = APIRouter()

Expand Down
5 changes: 3 additions & 2 deletions api/query.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from fastapi import APIRouter, Depends
from models.query import RequestPayload, ResponsePayload
from service.vector_database import get_vector_service, VectorService

from auth.user import get_current_api_user
from models.query import RequestPayload, ResponsePayload
from service.vector_database import VectorService, get_vector_service

router = APIRouter()

Expand Down
2 changes: 1 addition & 1 deletion auth/user.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
import jwt

import jwt
from decouple import config
from fastapi import HTTPException, Security
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
Expand Down
4 changes: 2 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from decouple import config
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from decouple import config
from router import router

from router import router

app = FastAPI(
title="SuperRag",
Expand Down
1 change: 1 addition & 0 deletions models/delete.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from pydantic import BaseModel

from models.vector_database import VectorDatabase


Expand Down
1 change: 1 addition & 0 deletions models/file.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from enum import Enum

from pydantic import BaseModel


Expand Down
2 changes: 2 additions & 0 deletions models/ingest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import List, Optional

from pydantic import BaseModel

from models.file import File
from models.vector_database import VectorDatabase

Expand Down
4 changes: 3 additions & 1 deletion models/query.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from pydantic import BaseModel
from typing import List, Optional

from pydantic import BaseModel

from models.vector_database import VectorDatabase


Expand Down
3 changes: 2 additions & 1 deletion models/vector_database.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Dict
from enum import Enum
from typing import Dict

from pydantic import BaseModel


Expand Down
10 changes: 7 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ cffi==1.16.0
charset-normalizer==3.3.2
click==8.1.7
cohere==4.42
coloredlogs==15.0.1
cryptography==41.0.7
dataclasses-json==0.6.3
Deprecated==1.2.14
Expand All @@ -23,7 +24,9 @@ dnspython==2.4.2
docx2txt==0.8
fastapi==0.109.0
fastavro==1.9.3
fastembed==0.1.3
filelock==3.13.1
flatbuffers==23.5.26
frozenlist==1.4.1
fsspec==2023.12.2
geomet==0.2.1.post1
Expand All @@ -37,13 +40,13 @@ hpack==4.0.0
httpcore==1.0.2
httptools==0.6.1
httpx==0.25.2
huggingface-hub==0.20.2
huggingface-hub==0.19.4
humanfriendly==10.0
hyperframe==6.0.1
idna==3.6
importlib-metadata==6.11.0
Jinja2==3.1.3
joblib==1.3.2
litellm==1.17.5
llama-index==0.9.30
loguru==0.7.2
lxml==5.1.0
Expand All @@ -56,6 +59,8 @@ nest-asyncio==1.5.8
networkx==3.2.1
nltk==3.8.1
numpy==1.26.3
onnx==1.15.0
onnxruntime==1.17.0
openai==1.7.2
packaging==23.2
pandas==2.1.4
Expand Down Expand Up @@ -94,7 +99,6 @@ tokenizers==0.15.0
toml==0.10.2
torch==2.1.2
tqdm==4.66.1
transformers==4.36.2
typing-inspect==0.9.0
typing_extensions==4.9.0
tzdata==2023.4
Expand Down
2 changes: 1 addition & 1 deletion router.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from fastapi import APIRouter

from api import ingest, query, delete
from api import delete, ingest, query

router = APIRouter()
api_prefix = "/api/v1"
Expand Down
24 changes: 10 additions & 14 deletions service/embedding.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import requests
import asyncio

from typing import Any, List, Union
from tempfile import NamedTemporaryFile
from typing import Any, List, Union

import numpy as np
import requests
from fastembed.embedding import FlagEmbedding as Embedding
from llama_index import Document, SimpleDirectoryReader
from llama_index.node_parser import SimpleNodeParser
from litellm import aembedding

from models.file import File
from decouple import config
from service.vector_database import get_vector_service


Expand Down Expand Up @@ -58,18 +59,13 @@ async def generate_embeddings(
) -> List[tuple[str, list, dict[str, Any]]]:
async def generate_embedding(node):
if node is not None:
vectors = []
embedding_object = await aembedding(
model="huggingface/intfloat/multilingual-e5-large",
input=node.text,
api_key=config("HUGGINGFACE_API_KEY"),
embedding_model = Embedding(
model_name="sentence-transformers/all-MiniLM-L6-v2", max_length=512
)
for vector in embedding_object.data:
if vector["object"] == "embedding":
vectors.append(vector["embedding"])
embeddings: List[np.ndarray] = list(embedding_model.embed(node.text))
embedding = (
node.id_,
vectors,
embeddings[0].tolist(),
{
**node.metadata,
"content": node.text,
Expand Down
44 changes: 12 additions & 32 deletions service/vector_database.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import weaviate

from abc import ABC, abstractmethod
from typing import Any, List, Type

import numpy as np
import weaviate
from astrapy.db import AstraDB
from decouple import config
from litellm import embedding
from fastembed.embedding import FlagEmbedding as Embedding
from pinecone import Pinecone, ServerlessSpec
from qdrant_client import QdrantClient
from qdrant_client.http import models as rest
from pinecone import Pinecone, ServerlessSpec
from astrapy.db import AstraDB

from models.vector_database import VectorDatabase

Expand Down Expand Up @@ -35,16 +36,11 @@ async def delete(self, file_url: str):
pass

async def _generate_vectors(sefl, input: str):
vectors = []
embedding_object = embedding(
model="huggingface/intfloat/multilingual-e5-large",
input=input,
api_key=config("HUGGINGFACE_API_KEY"),
embedding_model = Embedding(
model_name="sentence-transformers/all-MiniLM-L6-v2", max_length=512
)
for vector in embedding_object.data:
if vector["object"] == "embedding":
vectors.append(vector["embedding"])
return vectors
embeddings: List[np.ndarray] = list(embedding_model.embed(input))
return embeddings[0].tolist()

async def rerank(self, query: str, documents: list, top_n: int = 4):
from cohere import Client
Expand Down Expand Up @@ -97,15 +93,7 @@ async def upsert(self, embeddings: List[tuple[str, list, dict[str, Any]]]):
self.index.upsert(vectors=embeddings)

async def query(self, input: str, top_k: 4, include_metadata: bool = True):
vectors = []
embedding_object = embedding(
model="huggingface/intfloat/multilingual-e5-large",
input=input,
api_key=config("HUGGINGFACE_API_KEY"),
)
for vector in embedding_object.data:
if vector["object"] == "embedding":
vectors.append(vector["embedding"])
vectors = await self._generate_vectors(input=input)
results = self.index.query(
vector=vectors,
top_k=top_k,
Expand Down Expand Up @@ -164,15 +152,7 @@ async def upsert(self, embeddings: List[tuple[str, list, dict[str, Any]]]) -> No
self.client.upsert(collection_name=self.index_name, wait=True, points=points)

async def query(self, input: str, top_k: int) -> List:
vectors = []
embedding_object = embedding(
model="huggingface/intfloat/multilingual-e5-large",
input=input,
api_key=config("HUGGINGFACE_API_KEY"),
)
for vector in embedding_object.data:
if vector["object"] == "embedding":
vectors.append(vector["embedding"])
vectors = await self._generate_vectors(input=input)
search_result = self.client.search(
collection_name=self.index_name,
query_vector=("content", vectors),
Expand Down

0 comments on commit 26c38cf

Please sign in to comment.