Skip to content

Commit

Permalink
feat: Add support for dynamic providers in Astra DB Comp (langflow-ai…
Browse files Browse the repository at this point in the history
…#4627)

* feat: Add support for dynamic providers in Astra DB Comp

* [autofix.ci] apply automated fixes

* Make sure we return a default dict

* Rename params in starter template

* Update test_vector_store_rag.py

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
  • Loading branch information
2 people authored and mieslep committed Nov 19, 2024
1 parent 5ff8593 commit 6c95779
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 82 deletions.
123 changes: 85 additions & 38 deletions src/backend/base/langflow/components/vectorstores/astradb.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
from collections import defaultdict

import orjson
from astrapy import DataAPIClient
from astrapy.admin import parse_api_endpoint
from langchain_astradb import AstraDBVectorStore

Expand Down Expand Up @@ -29,39 +31,45 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):

_cached_vector_store: AstraDBVectorStore | None = None

VECTORIZE_PROVIDERS_MAPPING = {
"Azure OpenAI": ["azureOpenAI", ["text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002"]],
"Hugging Face - Dedicated": ["huggingfaceDedicated", ["endpoint-defined-model"]],
"Hugging Face - Serverless": [
"huggingface",
[
"sentence-transformers/all-MiniLM-L6-v2",
"intfloat/multilingual-e5-large",
"intfloat/multilingual-e5-large-instruct",
"BAAI/bge-small-en-v1.5",
"BAAI/bge-base-en-v1.5",
"BAAI/bge-large-en-v1.5",
VECTORIZE_PROVIDERS_MAPPING = defaultdict(
list,
{
"Azure OpenAI": [
"azureOpenAI",
["text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002"],
],
],
"Jina AI": [
"jinaAI",
[
"jina-embeddings-v2-base-en",
"jina-embeddings-v2-base-de",
"jina-embeddings-v2-base-es",
"jina-embeddings-v2-base-code",
"jina-embeddings-v2-base-zh",
"Hugging Face - Dedicated": ["huggingfaceDedicated", ["endpoint-defined-model"]],
"Hugging Face - Serverless": [
"huggingface",
[
"sentence-transformers/all-MiniLM-L6-v2",
"intfloat/multilingual-e5-large",
"intfloat/multilingual-e5-large-instruct",
"BAAI/bge-small-en-v1.5",
"BAAI/bge-base-en-v1.5",
"BAAI/bge-large-en-v1.5",
],
],
],
"Mistral AI": ["mistral", ["mistral-embed"]],
"NVIDIA": ["nvidia", ["NV-Embed-QA"]],
"OpenAI": ["openai", ["text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002"]],
"Upstage": ["upstageAI", ["solar-embedding-1-large"]],
"Voyage AI": [
"voyageAI",
["voyage-large-2-instruct", "voyage-law-2", "voyage-code-2", "voyage-large-2", "voyage-2"],
],
}
"Jina AI": [
"jinaAI",
[
"jina-embeddings-v2-base-en",
"jina-embeddings-v2-base-de",
"jina-embeddings-v2-base-es",
"jina-embeddings-v2-base-code",
"jina-embeddings-v2-base-zh",
],
],
"Mistral AI": ["mistral", ["mistral-embed"]],
"NVIDIA": ["nvidia", ["NV-Embed-QA"]],
"OpenAI": ["openai", ["text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002"]],
"Upstage": ["upstageAI", ["solar-embedding-1-large"]],
"Voyage AI": [
"voyageAI",
["voyage-large-2-instruct", "voyage-law-2", "voyage-code-2", "voyage-large-2", "voyage-2"],
],
},
)

inputs = [
SecretStrInput(
Expand Down Expand Up @@ -109,7 +117,7 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
value="Embedding Model",
),
HandleInput(
name="embedding",
name="embedding_model",
display_name="Embedding Model",
input_types=["Embeddings"],
info="Allows an embedding model configuration.",
Expand Down Expand Up @@ -247,15 +255,52 @@ def insert_in_dict(self, build_config, field_name, new_parameters):

return build_config

def update_providers_mapping(self) -> defaultdict:
    """Return the mapping of Vectorize provider display names to ``[provider_key, models]``.

    Tries to fetch the live embedding-provider list from the Astra DB Data API
    when both ``self.token`` and ``self.api_endpoint`` are set. On missing
    credentials or any API error, falls back to the static class-level
    ``VECTORIZE_PROVIDERS_MAPPING`` so the component still renders a usable
    (if possibly stale) provider dropdown.
    """
    # If we don't have token or api_endpoint, we can't fetch the list of providers
    if not self.token or not self.api_endpoint:
        self.log("Astra DB token and API endpoint are required to fetch the list of Vectorize providers.")

        # Fall back to the hard-coded mapping defined on the class.
        return self.VECTORIZE_PROVIDERS_MAPPING

    try:
        self.log("Dynamically updating list of Vectorize providers.")

        # Get the admin object
        client = DataAPIClient(token=self.token)
        admin = client.get_admin()

        # Get the embedding providers
        db_admin = admin.get_database_admin(self.api_endpoint)
        embedding_providers = db_admin.find_embedding_providers().as_dict()

        vectorize_providers_mapping = {}

        # Map the provider display name to the provider key and models
        for provider_key, provider_data in embedding_providers["embeddingProviders"].items():
            display_name = provider_data["displayName"]
            models = [model["name"] for model in provider_data["models"]]

            vectorize_providers_mapping[display_name] = [provider_key, models]

        # Sort the resulting dictionary
        # defaultdict(list) so lookups for unknown providers yield [] rather than KeyError,
        # matching the shape of the static class-level mapping.
        return defaultdict(list, dict(sorted(vectorize_providers_mapping.items())))
    except Exception as e:  # noqa: BLE001
        # Deliberately broad: any network/API failure degrades to the static mapping.
        self.log(f"Error fetching Vectorize providers: {e}")

        return self.VECTORIZE_PROVIDERS_MAPPING

def update_build_config(self, build_config: dict, field_value: str, field_name: str | None = None):
if field_name == "embedding_choice":
if field_value == "Astra Vectorize":
self.del_fields(build_config, ["embedding"])
self.del_fields(build_config, ["embedding_model"])

# Update the providers mapping
vectorize_providers = self.update_providers_mapping()

new_parameter = DropdownInput(
name="embedding_provider",
display_name="Embedding Provider",
options=self.VECTORIZE_PROVIDERS_MAPPING.keys(),
options=vectorize_providers.keys(),
value="",
required=True,
real_time_refresh=True,
Expand All @@ -276,21 +321,23 @@ def update_build_config(self, build_config: dict, field_value: str, field_name:
)

new_parameter = HandleInput(
name="embedding",
name="embedding_model",
display_name="Embedding Model",
input_types=["Embeddings"],
info="Allows an embedding model configuration.",
).to_dict()

self.insert_in_dict(build_config, "embedding_choice", {"embedding": new_parameter})
self.insert_in_dict(build_config, "embedding_choice", {"embedding_model": new_parameter})

elif field_name == "embedding_provider":
self.del_fields(
build_config,
["model", "z_01_model_parameters", "z_02_api_key_name", "z_03_provider_api_key", "z_04_authentication"],
)

model_options = self.VECTORIZE_PROVIDERS_MAPPING[field_value][1]
# Update the providers mapping
vectorize_providers = self.update_providers_mapping()
model_options = vectorize_providers[field_value][1]

new_parameter = DropdownInput(
name="model",
Expand Down Expand Up @@ -420,7 +467,7 @@ def build_vector_store(self, vectorize_options=None):
raise ValueError(msg) from e

if self.embedding_choice == "Embedding Model":
embedding_dict = {"embedding": self.embedding}
embedding_dict = {"embedding": self.embedding_model}
else:
from astrapy.info import CollectionVectorServiceOptions

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@
"output_types": ["Embeddings"]
},
"targetHandle": {
"fieldName": "embedding",
"fieldName": "embedding_model",
"id": "AstraDB-3buPx",
"inputTypes": ["Embeddings"],
"type": "other"
Expand Down Expand Up @@ -196,7 +196,7 @@
"output_types": ["Embeddings"]
},
"targetHandle": {
"fieldName": "embedding",
"fieldName": "embedding_model",
"id": "AstraDB-laybz",
"inputTypes": ["Embeddings"],
"type": "other"
Expand Down Expand Up @@ -1601,7 +1601,7 @@
"ingest_data",
"namespace",
"embedding_service",
"embedding",
"embedding_model",
"metric",
"batch_size",
"bulk_insert_batch_concurrency",
Expand Down Expand Up @@ -1781,23 +1781,6 @@
"type": "str",
"value": ""
},
"embedding": {
"_input_type": "HandleInput",
"advanced": false,
"display_name": "Embedding Model",
"dynamic": false,
"info": "Allows an embedding model configuration.",
"input_types": ["Embeddings"],
"list": false,
"name": "embedding",
"placeholder": "",
"required": false,
"show": true,
"title_case": false,
"trace_as_metadata": true,
"type": "other",
"value": ""
},
"embedding_choice": {
"_input_type": "DropdownInput",
"advanced": false,
Expand All @@ -1817,6 +1800,23 @@
"type": "str",
"value": "Embedding Model"
},
"embedding_model": {
"_input_type": "HandleInput",
"advanced": false,
"display_name": "Embedding Model",
"dynamic": false,
"info": "Allows an embedding model configuration.",
"input_types": ["Embeddings"],
"list": false,
"name": "embedding_model",
"placeholder": "",
"required": false,
"show": true,
"title_case": false,
"trace_as_metadata": true,
"type": "other",
"value": ""
},
"ingest_data": {
"_input_type": "DataInput",
"advanced": false,
Expand Down Expand Up @@ -2556,7 +2556,7 @@
"ingest_data",
"namespace",
"embedding_service",
"embedding",
"embedding_model",
"metric",
"batch_size",
"bulk_insert_batch_concurrency",
Expand Down Expand Up @@ -2736,23 +2736,6 @@
"type": "str",
"value": "test"
},
"embedding": {
"_input_type": "HandleInput",
"advanced": false,
"display_name": "Embedding Model",
"dynamic": false,
"info": "Allows an embedding model configuration.",
"input_types": ["Embeddings"],
"list": false,
"name": "embedding",
"placeholder": "",
"required": false,
"show": true,
"title_case": false,
"trace_as_metadata": true,
"type": "other",
"value": ""
},
"embedding_choice": {
"_input_type": "DropdownInput",
"advanced": false,
Expand All @@ -2772,6 +2755,23 @@
"type": "str",
"value": "Embedding Model"
},
"embedding_model": {
"_input_type": "HandleInput",
"advanced": false,
"display_name": "Embedding Model",
"dynamic": false,
"info": "Allows an embedding model configuration.",
"input_types": ["Embeddings"],
"list": false,
"name": "embedding_model",
"placeholder": "",
"required": false,
"show": true,
"title_case": false,
"trace_as_metadata": true,
"type": "other",
"value": ""
},
"ingest_data": {
"_input_type": "DataInput",
"advanced": false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def ingestion_graph():
openai_embeddings = OpenAIEmbeddingsComponent()
vector_store = AstraVectorStoreComponent()
vector_store.set(
embedding=openai_embeddings.build_embeddings,
embedding_model=openai_embeddings.build_embeddings,
ingest_data=text_splitter.split_text,
)

Expand All @@ -34,7 +34,7 @@ def rag_graph():
rag_vector_store = AstraVectorStoreComponent()
rag_vector_store.set(
search_input=chat_input.message_response,
embedding=openai_embeddings.build_embeddings,
embedding_model=openai_embeddings.build_embeddings,
)

parse_data = ParseDataComponent()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ async def test_base(astradb_client: AstraDB):
"token": application_token,
"api_endpoint": api_endpoint,
"collection_name": BASIC_COLLECTION,
"embedding": ComponentInputHandle(
"embedding_model": ComponentInputHandle(
clazz=OpenAIEmbeddingsComponent,
inputs={"openai_api_key": get_openai_api_key()},
output_name="embeddings",
Expand Down Expand Up @@ -79,7 +79,7 @@ async def test_astra_embeds_and_search():
"ingest_data": ComponentInputHandle(
clazz=TextToData, inputs={"text_data": ["test1", "test2"]}, output_name="from_text"
),
"embedding": ComponentInputHandle(
"embedding_model": ComponentInputHandle(
clazz=OpenAIEmbeddingsComponent,
inputs={"openai_api_key": get_openai_api_key()},
output_name="embeddings",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def ingestion_graph():
)
vector_store = AstraVectorStoreComponent(_id="vector-store-123")
vector_store.set(
embedding=openai_embeddings.build_embeddings,
embedding_model=openai_embeddings.build_embeddings,
ingest_data=text_splitter.split_text,
api_endpoint="https://astra.example.com",
token="token", # noqa: S106
Expand All @@ -53,7 +53,7 @@ def rag_graph():
search_input=chat_input.message_response,
api_endpoint="https://astra.example.com",
token="token", # noqa: S106
embedding=openai_embeddings.build_embeddings,
embedding_model=openai_embeddings.build_embeddings,
)
# Mock search_documents
rag_vector_store.set_on_output(
Expand Down

0 comments on commit 6c95779

Please sign in to comment.