Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add support for dynamic providers in Astra DB Comp #4627

Merged
merged 10 commits into from
Nov 18, 2024
123 changes: 85 additions & 38 deletions src/backend/base/langflow/components/vectorstores/astradb.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
from collections import defaultdict

import orjson
from astrapy import DataAPIClient
from astrapy.admin import parse_api_endpoint
from langchain_astradb import AstraDBVectorStore

Expand Down Expand Up @@ -29,39 +31,45 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):

_cached_vector_store: AstraDBVectorStore | None = None

VECTORIZE_PROVIDERS_MAPPING = {
"Azure OpenAI": ["azureOpenAI", ["text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002"]],
"Hugging Face - Dedicated": ["huggingfaceDedicated", ["endpoint-defined-model"]],
"Hugging Face - Serverless": [
"huggingface",
[
"sentence-transformers/all-MiniLM-L6-v2",
"intfloat/multilingual-e5-large",
"intfloat/multilingual-e5-large-instruct",
"BAAI/bge-small-en-v1.5",
"BAAI/bge-base-en-v1.5",
"BAAI/bge-large-en-v1.5",
VECTORIZE_PROVIDERS_MAPPING = defaultdict(
list,
{
"Azure OpenAI": [
"azureOpenAI",
["text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002"],
],
],
"Jina AI": [
"jinaAI",
[
"jina-embeddings-v2-base-en",
"jina-embeddings-v2-base-de",
"jina-embeddings-v2-base-es",
"jina-embeddings-v2-base-code",
"jina-embeddings-v2-base-zh",
"Hugging Face - Dedicated": ["huggingfaceDedicated", ["endpoint-defined-model"]],
"Hugging Face - Serverless": [
"huggingface",
[
"sentence-transformers/all-MiniLM-L6-v2",
"intfloat/multilingual-e5-large",
"intfloat/multilingual-e5-large-instruct",
"BAAI/bge-small-en-v1.5",
"BAAI/bge-base-en-v1.5",
"BAAI/bge-large-en-v1.5",
],
],
],
"Mistral AI": ["mistral", ["mistral-embed"]],
"NVIDIA": ["nvidia", ["NV-Embed-QA"]],
"OpenAI": ["openai", ["text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002"]],
"Upstage": ["upstageAI", ["solar-embedding-1-large"]],
"Voyage AI": [
"voyageAI",
["voyage-large-2-instruct", "voyage-law-2", "voyage-code-2", "voyage-large-2", "voyage-2"],
],
}
"Jina AI": [
"jinaAI",
[
"jina-embeddings-v2-base-en",
"jina-embeddings-v2-base-de",
"jina-embeddings-v2-base-es",
"jina-embeddings-v2-base-code",
"jina-embeddings-v2-base-zh",
],
],
"Mistral AI": ["mistral", ["mistral-embed"]],
"NVIDIA": ["nvidia", ["NV-Embed-QA"]],
"OpenAI": ["openai", ["text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002"]],
"Upstage": ["upstageAI", ["solar-embedding-1-large"]],
"Voyage AI": [
"voyageAI",
["voyage-large-2-instruct", "voyage-law-2", "voyage-code-2", "voyage-large-2", "voyage-2"],
],
},
)

inputs = [
SecretStrInput(
Expand Down Expand Up @@ -109,7 +117,7 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
value="Embedding Model",
),
HandleInput(
name="embedding",
name="embedding_model",
display_name="Embedding Model",
input_types=["Embeddings"],
info="Allows an embedding model configuration.",
Expand Down Expand Up @@ -247,15 +255,52 @@ def insert_in_dict(self, build_config, field_name, new_parameters):

return build_config

def update_providers_mapping(self):
    """Fetch the current list of Vectorize embedding providers from Astra DB.

    Queries the Data API for the providers available to this database and
    returns a mapping of provider display name -> [provider key, [model names]].
    Falls back to the static ``VECTORIZE_PROVIDERS_MAPPING`` when credentials
    are missing or the remote lookup fails.
    """
    # Without credentials there is nothing to query; use the static mapping.
    if not self.token or not self.api_endpoint:
        self.log("Astra DB token and API endpoint are required to fetch the list of Vectorize providers.")

        return self.VECTORIZE_PROVIDERS_MAPPING

    try:
        self.log("Dynamically updating list of Vectorize providers.")

        # Build an admin handle scoped to the target database.
        admin = DataAPIClient(token=self.token).get_admin()
        database_admin = admin.get_database_admin(self.api_endpoint)

        # Ask the Data API which embedding providers are currently supported.
        providers_payload = database_admin.find_embedding_providers().as_dict()

        # Provider display name -> [provider key, list of model names].
        mapping = {
            provider_info["displayName"]: [
                provider_id,
                [entry["name"] for entry in provider_info["models"]],
            ]
            for provider_id, provider_info in providers_payload["embeddingProviders"].items()
        }

        # Keep options alphabetically sorted; unknown keys yield an empty list.
        return defaultdict(list, sorted(mapping.items()))
    except Exception as e:  # noqa: BLE001
        self.log(f"Error fetching Vectorize providers: {e}")

        return self.VECTORIZE_PROVIDERS_MAPPING

def update_build_config(self, build_config: dict, field_value: str, field_name: str | None = None):
if field_name == "embedding_choice":
if field_value == "Astra Vectorize":
self.del_fields(build_config, ["embedding"])
self.del_fields(build_config, ["embedding_model"])

# Update the providers mapping
vectorize_providers = self.update_providers_mapping()

new_parameter = DropdownInput(
name="embedding_provider",
display_name="Embedding Provider",
options=self.VECTORIZE_PROVIDERS_MAPPING.keys(),
options=vectorize_providers.keys(),
value="",
required=True,
real_time_refresh=True,
Expand All @@ -276,21 +321,23 @@ def update_build_config(self, build_config: dict, field_value: str, field_name:
)

new_parameter = HandleInput(
name="embedding",
name="embedding_model",
display_name="Embedding Model",
input_types=["Embeddings"],
info="Allows an embedding model configuration.",
).to_dict()

self.insert_in_dict(build_config, "embedding_choice", {"embedding": new_parameter})
self.insert_in_dict(build_config, "embedding_choice", {"embedding_model": new_parameter})

elif field_name == "embedding_provider":
self.del_fields(
build_config,
["model", "z_01_model_parameters", "z_02_api_key_name", "z_03_provider_api_key", "z_04_authentication"],
)

model_options = self.VECTORIZE_PROVIDERS_MAPPING[field_value][1]
# Update the providers mapping
vectorize_providers = self.update_providers_mapping()
model_options = vectorize_providers[field_value][1]

new_parameter = DropdownInput(
name="model",
Expand Down Expand Up @@ -420,7 +467,7 @@ def build_vector_store(self, vectorize_options=None):
raise ValueError(msg) from e

if self.embedding_choice == "Embedding Model":
embedding_dict = {"embedding": self.embedding}
embedding_dict = {"embedding": self.embedding_model}
else:
from astrapy.info import CollectionVectorServiceOptions

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@
"output_types": ["Embeddings"]
},
"targetHandle": {
"fieldName": "embedding",
"fieldName": "embedding_model",
"id": "AstraDB-3buPx",
"inputTypes": ["Embeddings"],
"type": "other"
Expand Down Expand Up @@ -196,7 +196,7 @@
"output_types": ["Embeddings"]
},
"targetHandle": {
"fieldName": "embedding",
"fieldName": "embedding_model",
"id": "AstraDB-laybz",
"inputTypes": ["Embeddings"],
"type": "other"
Expand Down Expand Up @@ -1601,7 +1601,7 @@
"ingest_data",
"namespace",
"embedding_service",
"embedding",
"embedding_model",
"metric",
"batch_size",
"bulk_insert_batch_concurrency",
Expand Down Expand Up @@ -1781,23 +1781,6 @@
"type": "str",
"value": ""
},
"embedding": {
"_input_type": "HandleInput",
"advanced": false,
"display_name": "Embedding Model",
"dynamic": false,
"info": "Allows an embedding model configuration.",
"input_types": ["Embeddings"],
"list": false,
"name": "embedding",
"placeholder": "",
"required": false,
"show": true,
"title_case": false,
"trace_as_metadata": true,
"type": "other",
"value": ""
},
"embedding_choice": {
"_input_type": "DropdownInput",
"advanced": false,
Expand All @@ -1817,6 +1800,23 @@
"type": "str",
"value": "Embedding Model"
},
"embedding_model": {
"_input_type": "HandleInput",
"advanced": false,
"display_name": "Embedding Model",
"dynamic": false,
"info": "Allows an embedding model configuration.",
"input_types": ["Embeddings"],
"list": false,
"name": "embedding_model",
"placeholder": "",
"required": false,
"show": true,
"title_case": false,
"trace_as_metadata": true,
"type": "other",
"value": ""
},
"ingest_data": {
"_input_type": "DataInput",
"advanced": false,
Expand Down Expand Up @@ -2556,7 +2556,7 @@
"ingest_data",
"namespace",
"embedding_service",
"embedding",
"embedding_model",
"metric",
"batch_size",
"bulk_insert_batch_concurrency",
Expand Down Expand Up @@ -2736,23 +2736,6 @@
"type": "str",
"value": "test"
},
"embedding": {
"_input_type": "HandleInput",
"advanced": false,
"display_name": "Embedding Model",
"dynamic": false,
"info": "Allows an embedding model configuration.",
"input_types": ["Embeddings"],
"list": false,
"name": "embedding",
"placeholder": "",
"required": false,
"show": true,
"title_case": false,
"trace_as_metadata": true,
"type": "other",
"value": ""
},
"embedding_choice": {
"_input_type": "DropdownInput",
"advanced": false,
Expand All @@ -2772,6 +2755,23 @@
"type": "str",
"value": "Embedding Model"
},
"embedding_model": {
"_input_type": "HandleInput",
"advanced": false,
"display_name": "Embedding Model",
"dynamic": false,
"info": "Allows an embedding model configuration.",
"input_types": ["Embeddings"],
"list": false,
"name": "embedding_model",
"placeholder": "",
"required": false,
"show": true,
"title_case": false,
"trace_as_metadata": true,
"type": "other",
"value": ""
},
"ingest_data": {
"_input_type": "DataInput",
"advanced": false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def ingestion_graph():
openai_embeddings = OpenAIEmbeddingsComponent()
vector_store = AstraVectorStoreComponent()
vector_store.set(
embedding=openai_embeddings.build_embeddings,
embedding_model=openai_embeddings.build_embeddings,
ingest_data=text_splitter.split_text,
)

Expand All @@ -34,7 +34,7 @@ def rag_graph():
rag_vector_store = AstraVectorStoreComponent()
rag_vector_store.set(
search_input=chat_input.message_response,
embedding=openai_embeddings.build_embeddings,
embedding_model=openai_embeddings.build_embeddings,
)

parse_data = ParseDataComponent()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ async def test_base(astradb_client: AstraDB):
"token": application_token,
"api_endpoint": api_endpoint,
"collection_name": BASIC_COLLECTION,
"embedding": ComponentInputHandle(
"embedding_model": ComponentInputHandle(
clazz=OpenAIEmbeddingsComponent,
inputs={"openai_api_key": get_openai_api_key()},
output_name="embeddings",
Expand Down Expand Up @@ -79,7 +79,7 @@ async def test_astra_embeds_and_search():
"ingest_data": ComponentInputHandle(
clazz=TextToData, inputs={"text_data": ["test1", "test2"]}, output_name="from_text"
),
"embedding": ComponentInputHandle(
"embedding_model": ComponentInputHandle(
clazz=OpenAIEmbeddingsComponent,
inputs={"openai_api_key": get_openai_api_key()},
output_name="embeddings",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def ingestion_graph():
)
vector_store = AstraVectorStoreComponent(_id="vector-store-123")
vector_store.set(
embedding=openai_embeddings.build_embeddings,
embedding_model=openai_embeddings.build_embeddings,
ingest_data=text_splitter.split_text,
api_endpoint="https://astra.example.com",
token="token", # noqa: S106
Expand All @@ -53,7 +53,7 @@ def rag_graph():
search_input=chat_input.message_response,
api_endpoint="https://astra.example.com",
token="token", # noqa: S106
embedding=openai_embeddings.build_embeddings,
embedding_model=openai_embeddings.build_embeddings,
)
# Mock search_documents
rag_vector_store.set_on_output(
Expand Down
Loading