From a4ab39811cee57fac3a7a8ebcc320ebf578f06aa Mon Sep 17 00:00:00 2001 From: "Aivin V. Solatorio" Date: Sun, 25 Feb 2024 23:38:50 -0500 Subject: [PATCH] Use GIST Embeddings (#28) * Use langchain-community Signed-off-by: Aivin V. Solatorio * Update qdrant langchain_community Signed-off-by: Aivin V. Solatorio * Use GIST Embedding for docs Signed-off-by: Aivin V. Solatorio * Use GIST embedding for indicators Signed-off-by: Aivin V. Solatorio --------- Signed-off-by: Aivin V. Solatorio --- llm4data/embeddings/base.py | 8 ++++++-- llm4data/embeddings/docs.py | 2 +- llm4data/embeddings/indicators.py | 16 +++++++++++----- llm4data/index/qdrant.py | 2 +- 4 files changed, 19 insertions(+), 9 deletions(-) diff --git a/llm4data/embeddings/base.py b/llm4data/embeddings/base.py index 4aa1e68..d319608 100644 --- a/llm4data/embeddings/base.py +++ b/llm4data/embeddings/base.py @@ -1,6 +1,6 @@ """Base classes for embedding models.""" from typing import Union, Optional -from langchain import embeddings as langchain_embeddings +from langchain_community import embeddings as langchain_embeddings from pydantic.main import ModelMetaclass from qdrant_client.http import models from pydantic.main import ModelMetaclass @@ -17,6 +17,10 @@ class BaseEmbeddingModel: "instruct": 768, "all-MiniLM-L6-v2": 384, "multi-qa-mpnet-base-dot-v1": 768, + "avsolatorio/GIST-all-MiniLM-L6-v2": 384, + "avsolatorio/GIST-small-Embedding-v0": 384, + "avsolatorio/GIST-Embedding-v0": 768, + "avsolatorio/GIST-large-Embedding-v0": 1024, } model_name: str distance: Union[str, models.Distance] @@ -36,7 +40,7 @@ class BaseEmbeddingModel: @property def model_id(self): - return f"{self.data_type}_{self.model_name}_{self.collection_name}_{self.distance}_{self.size}_{self.max_tokens}_{self.is_instruct}" + return f"{self.data_type}_{self.model_name.replace('/', '_')}_{self.collection_name}_{self.distance}_{self.size}_{self.max_tokens}_{self.is_instruct}" def dict(self): return asdict(self) diff --git 
a/llm4data/embeddings/docs.py b/llm4data/embeddings/docs.py index 598dd8f..97b2555 100644 --- a/llm4data/embeddings/docs.py +++ b/llm4data/embeddings/docs.py @@ -15,7 +15,7 @@ def get_docs_embeddings(): if DOCS_EMBEDDINGS is None: DOCS_EMBEDDINGS = DocsEmbedding( - model_name="all-MiniLM-L6-v2", + model_name="avsolatorio/GIST-small-Embedding-v0", distance="Cosine", embedding_cls="HuggingFaceEmbeddings", is_instruct=False, diff --git a/llm4data/embeddings/indicators.py b/llm4data/embeddings/indicators.py index dbbb2c7..7b36185 100644 --- a/llm4data/embeddings/indicators.py +++ b/llm4data/embeddings/indicators.py @@ -14,13 +14,19 @@ def get_indicators_embeddings(): global INDICATORS_EMBEDDINGS if INDICATORS_EMBEDDINGS is None: + # INDICATORS_EMBEDDINGS = IndicatorsEmbedding( + # model_name="instruct", + # distance="Cosine", + # embedding_cls="HuggingFaceInstructEmbeddings", + # is_instruct=True, + # embed_instruction="Represent the Economic Development description for retrieval; Input: ", + # query_instruction="Represent the Economic Development prompt for retrieving descriptions; Input: ", + # ) INDICATORS_EMBEDDINGS = IndicatorsEmbedding( - model_name="instruct", + model_name="avsolatorio/GIST-small-Embedding-v0", distance="Cosine", - embedding_cls="HuggingFaceInstructEmbeddings", - is_instruct=True, - embed_instruction="Represent the Economic Development description for retrieval; Input: ", - query_instruction="Represent the Economic Development prompt for retrieving descriptions; Input: ", + embedding_cls="HuggingFaceEmbeddings", + is_instruct=False, ) return INDICATORS_EMBEDDINGS diff --git a/llm4data/index/qdrant.py b/llm4data/index/qdrant.py index e51c7bf..dcde95e 100644 --- a/llm4data/index/qdrant.py +++ b/llm4data/index/qdrant.py @@ -1,6 +1,6 @@ import os from typing import Optional, Union -from langchain.vectorstores import Qdrant +from langchain_community.vectorstores import Qdrant import qdrant_client from qdrant_client.http import models