diff --git a/src/agrag/agrag.py b/src/agrag/agrag.py
index 5c949cc8..c49c7da3 100644
--- a/src/agrag/agrag.py
+++ b/src/agrag/agrag.py
@@ -186,7 +186,7 @@ def initialize_data_module(self):
     def initialize_embeddings_module(self):
         """Initializes the Embedding module."""
         self.embedding_module = EmbeddingModule(
-            hf_model=self.args.hf_embedding_model,
+            model_name=self.args.hf_embedding_model,
             pooling_strategy=self.args.pooling_strategy,
             normalize_embeddings=self.args.normalize_embeddings,
             hf_model_params=self.args.hf_model_params,
@@ -195,6 +195,9 @@ def initialize_embeddings_module(self):
             hf_forward_params=self.args.hf_forward_params,
             normalization_params=self.args.normalization_params,
             query_instruction_for_retrieval=self.args.query_instruction_for_retrieval,
+            use_bedrock=self.args.embedding_use_bedrock,
+            bedrock_embedding_params=self.args.bedrock_embedding_params,
+            bedrock_aws_region=self.args.embedding_bedrock_aws_region,
         )
         logger.info("Embedding module initialized")
diff --git a/src/agrag/args.py b/src/agrag/args.py
index 9f3bad1e..426cf85d 100644
--- a/src/agrag/args.py
+++ b/src/agrag/args.py
@@ -258,6 +258,30 @@ def embedding_batch_size(self):
     def embedding_batch_size(self, value):
         self.config["embedding"]["embedding_batch_size"] = value
 
+    @property
+    def embedding_use_bedrock(self):
+        return self.config.get("embedding", {}).get("use_bedrock", self.embedding_defaults.get("USE_BEDROCK"))
+
+    @embedding_use_bedrock.setter
+    def embedding_use_bedrock(self, value):
+        self.config["embedding"]["use_bedrock"] = value
+
+    @property
+    def bedrock_embedding_params(self):
+        return self.config.get("embedding", {}).get("bedrock_embedding_params", {})
+
+    @bedrock_embedding_params.setter
+    def bedrock_embedding_params(self, value):
+        self.config["embedding"]["bedrock_embedding_params"] = value
+
+    @property
+    def embedding_bedrock_aws_region(self):
+        return self.config.get("embedding", {}).get("bedrock_aws_region", None)
+
+    @embedding_bedrock_aws_region.setter
+    def embedding_bedrock_aws_region(self, value):
+        self.config["embedding"]["bedrock_aws_region"] = value
+
     @property
     def vector_db_type(self):
         return self.config.get("vector_db", {}).get("db_type", self.vector_db_defaults.get("DB_TYPE"))
diff --git a/src/agrag/configs/embedding/default.yaml b/src/agrag/configs/embedding/default.yaml
index c0573f0a..b55aecd3 100644
--- a/src/agrag/configs/embedding/default.yaml
+++ b/src/agrag/configs/embedding/default.yaml
@@ -1,5 +1,5 @@
 DEFAULT_EMBEDDING_MODEL: "BAAI/bge-large-en"
-POOLING_STRATEGY: None
+POOLING_STRATEGY: null
 NORMALIZE_EMBEDDINGS: False
 HF_TOKENIZER_PARAMS: {"truncation": True, "padding": True}
 EMBEDDING_BATCH_SIZE: 128
diff --git a/src/agrag/configs/presets/medium_quality_config.yaml b/src/agrag/configs/presets/medium_quality_config.yaml
index 47c95dc7..76baafdd 100644
--- a/src/agrag/configs/presets/medium_quality_config.yaml
+++ b/src/agrag/configs/presets/medium_quality_config.yaml
@@ -9,8 +9,9 @@ data:
     - "txt"
 
 embedding:
-  embedding_model: BAAI/bge-large-en
-  pooling_strategy: cls
+  embedding_model: cohere.embed-english-v3
+  use_bedrock: true
+  bedrock_aws_region: us-west-2
   normalize_embeddings: false
   hf_tokenizer_params:
     truncation: true
diff --git a/src/agrag/modules/embedding/README.md b/src/agrag/modules/embedding/README.md
index 0d87a734..ac9726ae 100644
--- a/src/agrag/modules/embedding/README.md
+++ b/src/agrag/modules/embedding/README.md
@@ -6,7 +6,7 @@ Here are the configurable parameters for this module:
 
 ```
 embedding:
-  embedding_model: The name of the Huggingface model to use for generating embeddings (default is "BAAI/bge-large-en").
+  embedding_model: The name of the Huggingface or Bedrock model to use for generating embeddings (default is "BAAI/bge-large-en" from Huggingface). Currently, only Amazon Titan and Cohere embedding models are supported on Bedrock.
 
   pooling_strategy: The strategy to use for pooling embeddings. Options are 'mean', 'max', 'cls' (default is None).
@@ -23,4 +23,11 @@ embedding:
   normalization_params: Additional parameters to pass to the PyTorch `nn.functional.normalize` method.
 
   query_instruction_for_retrieval: Instruction for query when using embedding model.
+
+  use_bedrock: Whether to use the specified model through the AWS Bedrock API (https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html).
+    Currently, only Cohere (https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-embed.html) and Amazon Titan (https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan.html) embedding models are supported.
+
+  bedrock_embedding_params: Additional parameters to pass into the model when generating the embeddings.
+
+  bedrock_aws_region: AWS region where the model is hosted on Bedrock.
 ```
\ No newline at end of file
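For reference, the new keys compose like this in an `embedding` config section (a minimal sketch; the `truncate` entry is only an illustration of a Cohere parameter forwarded through `bedrock_embedding_params`, not something this patch sets):

```
embedding:
  embedding_model: cohere.embed-english-v3
  use_bedrock: true
  bedrock_aws_region: us-west-2
  bedrock_embedding_params:
    truncate: END
```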
diff --git a/src/agrag/modules/embedding/embedding.py b/src/agrag/modules/embedding/embedding.py
index 693ed82f..aec30b2b 100644
--- a/src/agrag/modules/embedding/embedding.py
+++ b/src/agrag/modules/embedding/embedding.py
@@ -1,6 +1,8 @@
+import json
 import logging
 from typing import List, Union
 
+import boto3
 import numpy as np
 import pandas as pd
 import torch
@@ -9,7 +11,7 @@
 from transformers import AutoModel, AutoTokenizer
 
 from agrag.constants import DOC_TEXT_KEY, EMBEDDING_HIDDEN_DIM_KEY, EMBEDDING_KEY
-from agrag.modules.embedding.utils import normalize_embedding, pool
+from agrag.modules.embedding.utils import get_embeddings_bedrock, normalize_embedding, pool
 
 logger = logging.getLogger("rag-logger")
 
@@ -20,8 +22,8 @@ class EmbeddingModule:
 
     Attributes:
     ----------
-    hf_model : str
-        The name of the Huggingface model to use for generating embeddings (default is "BAAI/bge-large-en").
+    model_name : str
+        The name of the Huggingface or Bedrock model to use for generating embeddings (default is "BAAI/bge-large-en" from Huggingface).
     pooling_strategy : str
         The strategy to use for pooling embeddings. Options are 'mean', 'max', 'cls' (default is None).
     normalize_embeddings: bool
@@ -38,6 +40,13 @@ class EmbeddingModule:
         Additional parameters to pass to the PyTorch `nn.functional.normalize` method.
     query_instruction_for_retrieval: str
         Instruction for query when using embedding model.
+    use_bedrock: bool
+        Whether to use the specified model through the AWS Bedrock API (https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html).
+        Currently, only Cohere (https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-embed.html) and Amazon Titan (https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan.html) embedding models are supported.
+    bedrock_embedding_params: dict
+        Additional parameters to pass into the model when generating the embeddings.
+    bedrock_aws_region: str
+        AWS region where the model is hosted on Bedrock.
 
     Methods:
     -------
@@ -50,12 +59,12 @@ class EmbeddingModule:
     def __init__(
         self,
-        hf_model: str = "BAAI/bge-large-en",
+        model_name: str = "BAAI/bge-large-en",
         pooling_strategy: str = None,
         normalize_embeddings: bool = False,
         **kwargs,
     ):
-        self.hf_model = hf_model
+        self.model_name = model_name
         self.pooling_strategy = pooling_strategy
         self.normalize_embeddings = normalize_embeddings
         self.hf_model_params = kwargs.get("hf_model_params", {})
@@ -66,14 +75,25 @@ def __init__(
         self.query_instruction_for_retrieval = kwargs.get("query_instruction_for_retrieval", None)
         self.num_gpus = kwargs.get("num_gpus", 0)
         self.device = "cpu" if not self.num_gpus else torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-        logger.info(f"Using Huggingface Model {self.hf_model} for Embedding Module")
-        self.tokenizer = AutoTokenizer.from_pretrained(self.hf_model, **self.hf_tokenizer_init_params)
-        self.model = AutoModel.from_pretrained(self.hf_model, **self.hf_model_params)
-        if self.num_gpus > 1:
-            logger.info(f"Using {self.num_gpus} GPUs")
-            self.model = DataParallel(self.model)
-        self.model.to(self.device)
+        self.use_bedrock = kwargs.get("use_bedrock", False)
+        if self.use_bedrock:
+            if "embed" not in self.model_name:
+                raise ValueError(
+                    f"Invalid model_id {self.model_name}. Must use an embedding model from Bedrock. The model_id should contain 'embed'."
+                )
+            logger.info(f"Using Bedrock Model {self.model_name} for Embedding Module")
+            self.bedrock_embedding_params = kwargs.get("bedrock_embedding_params", {})
+            if "cohere" in self.model_name:
+                self.bedrock_embedding_params["input_type"] = "search_document"
+            self.client = boto3.client("bedrock-runtime", region_name=kwargs.get("bedrock_aws_region", None))
+        else:
+            logger.info(f"Using Huggingface Model {self.model_name} for Embedding Module")
+            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, **self.hf_tokenizer_init_params)
+            self.model = AutoModel.from_pretrained(self.model_name, **self.hf_model_params)
+            if self.num_gpus > 1:
+                logger.info(f"Using {self.num_gpus} GPUs")
+                self.model = DataParallel(self.model)
+            self.model.to(self.device)
 
     def encode(self, data: pd.DataFrame, pbar: tqdm = None, batch_size: int = 32) -> pd.DataFrame:
         """
@@ -112,24 +132,35 @@ def encode(self, data: pd.DataFrame, pbar: tqdm = None, batch_size: int = 32) ->
             logger.info("\nTokenizing text chunks")
             batch_texts = texts[i : i + batch_size]
-            inputs = self.tokenizer(batch_texts, return_tensors="pt", **self.hf_tokenizer_params)
-            inputs = {k: v.to(self.device) for k, v in inputs.items()}
 
             logger.info("\nGenerating embeddings")
-            with torch.no_grad():
-                outputs = self.model(**inputs, **self.hf_forward_params)
+            if self.use_bedrock:
+                batch_embeddings = get_embeddings_bedrock(
+                    batch_texts=batch_texts,
+                    client=self.client,
+                    model_id=self.model_name,
+                    embedding_params=self.bedrock_embedding_params,
+                )
+            else:
+                inputs = self.tokenizer(batch_texts, return_tensors="pt", **self.hf_tokenizer_params)
+                inputs = {k: v.to(self.device) for k, v in inputs.items()}
+                with torch.no_grad():
+                    outputs = self.model(**inputs, **self.hf_forward_params)
 
-            logger.info("\nProcessing embeddings")
+                # The first element in the tuple returned by the model is the embeddings generated.
+                # The tuple elements are (embeddings, hidden_states, past_key_values, attentions, cross_attentions).
+                batch_embeddings = outputs[0]
 
-            # The first element in the tuple returned by the model is the embeddings generated
-            # The tuple elements are (embeddings, hidden_states, past_key_values, attentions, cross_attentions)
-            batch_embeddings = outputs[0]
+            logger.info("\nProcessing embeddings")
 
             batch_embeddings = pool(batch_embeddings, self.pooling_strategy)
             if self.normalize_embeddings:
                 batch_embeddings = normalize_embedding(batch_embeddings, **self.normalization_params)
-            batch_embeddings = batch_embeddings.cpu().numpy()
+            if isinstance(batch_embeddings, torch.Tensor):
+                batch_embeddings = batch_embeddings.cpu().numpy()
+            else:
+                batch_embeddings = np.array(batch_embeddings)
 
             all_embeddings.extend(batch_embeddings)
             all_embeddings_hidden_dim.extend([batch_embeddings.shape[-1]] * batch_embeddings.shape[0])
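To make the two constructor branches concrete, a minimal usage sketch (model IDs and region are illustrative):

```
from agrag.modules.embedding.embedding import EmbeddingModule

# Hugging Face path: tokenizer and model are loaded locally.
hf_module = EmbeddingModule(model_name="BAAI/bge-large-en", pooling_strategy="cls")

# Bedrock path: no local model is loaded; a bedrock-runtime client is created
# instead, and Cohere models default to input_type="search_document" for indexing.
bedrock_module = EmbeddingModule(
    model_name="cohere.embed-english-v3",
    use_bedrock=True,
    bedrock_aws_region="us-west-2",
)
```

The `"embed" not in model_name` guard rejects non-embedding Bedrock model IDs before any network call is made.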
diff --git a/src/agrag/modules/embedding/utils.py b/src/agrag/modules/embedding/utils.py
index 0b07577f..89c2e1d3 100644
--- a/src/agrag/modules/embedding/utils.py
+++ b/src/agrag/modules/embedding/utils.py
@@ -1,8 +1,13 @@
-from typing import List
+import json
+import logging
+from typing import Dict, List
 
+import boto3
 import torch
 from torch.nn import functional as F
 
+logger = logging.getLogger("rag-logger")
+
 
 def pool(embeddings: List[torch.Tensor], pooling_strategy: str) -> List[torch.Tensor]:
     """
@@ -73,3 +78,70 @@ def normalize_embedding(embeddings, args=None):
         normalized_embeddings = normalize(embeddings, args)
     """
     return F.normalize(embeddings, **args)
+
+
+def get_embeddings_bedrock(
+    batch_texts: List[str], client: boto3.client, model_id: str, embedding_params: dict = {}
+) -> List[List[float]]:
+    """Generates one embedding vector per input text using an Amazon Titan or Cohere model on Bedrock."""
+    embeddings = []
+    if "titan" in model_id:
+        # Titan embedding models accept a single text per request.
+        for text in batch_texts:
+            body = json.dumps(
+                {
+                    "inputText": text,
+                    **embedding_params,
+                }
+            )
+            response = client.invoke_model(
+                body=body,
+                modelId=model_id,
+                accept="application/json",
+                contentType="application/json",
+            )
+            outputs = json.loads(response["body"].read())
+            embeddings.append(outputs.get("embedding"))
+    elif "cohere" in model_id:
+        # Cohere embedding models accept the whole batch in one request.
+        body = json.dumps(
+            {
+                "texts": batch_texts,
+                **embedding_params,
+            }
+        )
+        response = client.invoke_model(
+            body=body,
+            modelId=model_id,
+            accept="application/json",
+            contentType="application/json",
+        )
+        outputs = json.loads(response["body"].read())
+        embeddings = outputs.get("embeddings")
+    else:
+        raise NotImplementedError(f"Unsupported Embedding Model for Bedrock {model_id}")
+    return embeddings
+
+
+def extract_response(output: Dict) -> List[float]:
+    """
+    Extracts the response embeddings from the model output.
+
+    Parameters:
+    ----------
+    output : Dict
+        The output dictionary from the Bedrock model.
+
+    Returns:
+    -------
+    List[float]
+        The extracted response embeddings.
+    """
+    # Used for Mistral response
+    if "outputs" in output and isinstance(output["outputs"], list) and "embedding" in output["outputs"][0]:
+        return output["outputs"][0]["embedding"]
+    # Used for Anthropic response
+    elif "content" in output and output["type"] == "message":
+        return output["content"][0]["embedding"]
+    # Used for Llama response
+    elif "generation" in output:
+        return output["generation"]
+    else:
+        raise ValueError(f"Unknown output structure: {output}")
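A usage sketch for the helper above (the Titan model ID is an illustrative example; AWS credentials must already be configured for boto3):

```
import boto3

from agrag.modules.embedding.utils import get_embeddings_bedrock

client = boto3.client("bedrock-runtime", region_name="us-west-2")

# Titan: invoked once per text; each response carries {"embedding": [...]}.
titan_vectors = get_embeddings_bedrock(
    batch_texts=["hello", "world"],
    client=client,
    model_id="amazon.titan-embed-text-v1",
)

# Cohere: invoked once per batch; the response carries {"embeddings": [[...], ...]}.
cohere_vectors = get_embeddings_bedrock(
    batch_texts=["hello", "world"],
    client=client,
    model_id="cohere.embed-english-v3",
    embedding_params={"input_type": "search_document"},
)
```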
+ """ + # Used for Mistral response + if "outputs" in output and isinstance(output["outputs"], list) and "embedding" in output["outputs"][0]: + return output["outputs"][0]["embedding"] + # Used for Anthropic response + elif "content" in output and output["type"] == "message": + return output["content"][0]["embedding"] + # Used for Llama response + elif "generation" in output: + return output["generation"] + else: + raise ValueError("Unknown output structure: %s", output) diff --git a/src/agrag/modules/generator/generators/bedrock_generator.py b/src/agrag/modules/generator/generators/bedrock_generator.py index dfe0aad5..cf3dd45e 100644 --- a/src/agrag/modules/generator/generators/bedrock_generator.py +++ b/src/agrag/modules/generator/generators/bedrock_generator.py @@ -91,5 +91,4 @@ def extract_response(self, output: Dict) -> str: elif "generation" in output: return output["generation"].strip() else: - logger.error("Unknown output structure: %s", output) - return "" + raise ValueError("Unknown output structure: %s", output) diff --git a/src/agrag/modules/retriever/retrievers/retriever_base.py b/src/agrag/modules/retriever/retrievers/retriever_base.py index e5fd4c9b..22860545 100644 --- a/src/agrag/modules/retriever/retrievers/retriever_base.py +++ b/src/agrag/modules/retriever/retrievers/retriever_base.py @@ -66,6 +66,8 @@ def encode_query(self, query: str) -> np.ndarray: np.ndarray The embedding of the query. """ + if self.embedding_module.use_bedrock and "cohere" in self.embedding_module.model_name: + self.embedding_module.bedrock_embedding_params["input_type"] = "search_query" query_embedding = self.embedding_module.encode(data=pd.DataFrame([{DOC_TEXT_KEY: query}])) query_embedding = query_embedding[EMBEDDING_KEY][0] return query_embedding diff --git a/tests/unittests/embedding/test_embedding.py b/tests/unittests/embedding/test_embedding.py index cb37ed49..d8fae7eb 100644 --- a/tests/unittests/embedding/test_embedding.py +++ b/tests/unittests/embedding/test_embedding.py @@ -19,15 +19,15 @@ def setUp(self, mock_model, mock_tokenizer): mock_tokenizer.return_value = self.mock_tokenizer mock_model.return_value = self.mock_model - hf_model_params = {"param": "param"} + model_name_params = {"param": "param"} hf_tokenizer_params = {"param": True} tokenizer_params = {"padding": 10, "max_length": 512} forward_params = {"param": True} self.embedding_module = EmbeddingModule( - hf_model="some-model", + model_name="some-model", pooling_strategy=None, - hf_model_params=hf_model_params, + model_name_params=model_name_params, hf_tokenizer_init_params=hf_tokenizer_params, hf_tokenizer_params=tokenizer_params, hf_forward_params=forward_params,