Support for Bedrock Embedding models (#39)
shreyash2106 authored Jul 30, 2024
1 parent 383bda1 commit 6a4ee7b
Showing 10 changed files with 172 additions and 33 deletions.
5 changes: 4 additions & 1 deletion src/agrag/agrag.py
@@ -186,7 +186,7 @@ def initialize_data_module(self):
def initialize_embeddings_module(self):
"""Initializes the Embedding module."""
self.embedding_module = EmbeddingModule(
hf_model=self.args.hf_embedding_model,
model_name=self.args.hf_embedding_model,
pooling_strategy=self.args.pooling_strategy,
normalize_embeddings=self.args.normalize_embeddings,
hf_model_params=self.args.hf_model_params,
@@ -195,6 +195,9 @@ def initialize_embeddings_module(self):
hf_forward_params=self.args.hf_forward_params,
normalization_params=self.args.normalization_params,
query_instruction_for_retrieval=self.args.query_instruction_for_retrieval,
use_bedrock=self.args.embedding_use_bedrock,
bedrock_embedding_params=self.args.bedrock_embedding_params,
bedrock_aws_region=self.args.embedding_bedrock_aws_region,
)
logger.info("Embedding module initialized")

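For orientation, a minimal sketch (not part of the commit) of how the three new keyword arguments reach the embedding module; the import path mirrors the repo layout and the values are illustrative:

```
# Hypothetical wiring, mirroring initialize_embeddings_module above.
# Model ID, region, and params are illustrative, not repo defaults;
# the Bedrock path constructs a boto3 client, so boto3 must be installed.
from agrag.modules.embedding.embedding import EmbeddingModule

embedding_module = EmbeddingModule(
    model_name="cohere.embed-english-v3",
    pooling_strategy=None,
    normalize_embeddings=False,
    use_bedrock=True,
    bedrock_embedding_params={},
    bedrock_aws_region="us-west-2",
)
```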
24 changes: 24 additions & 0 deletions src/agrag/args.py
@@ -258,6 +258,30 @@ def embedding_batch_size(self):
def embedding_batch_size(self, value):
self.config["embedding"]["embedding_batch_size"] = value

@property
def embedding_use_bedrock(self):
return self.config.get("embedding", {}).get("use_bedrock", self.embedding_defaults.get("USE_BEDROCK"))

@embedding_use_bedrock.setter
def embedding_use_bedrock(self, value):
self.config["embedding"]["use_bedrock"] = value

@property
def bedrock_embedding_params(self):
return self.config.get("embedding", {}).get("bedrock_embedding_params", {})

@bedrock_embedding_params.setter
def bedrock_embedding_params(self, value):
self.config["embedding"]["bedrock_embedding_params"] = value

@property
def embedding_bedrock_aws_region(self):
return self.config.get("embedding", {}).get("bedrock_aws_region", {})

@embedding_bedrock_aws_region.setter
def embedding_bedrock_aws_region(self, value):
self.config["embedding"]["bedrock_aws_region"] = value

@property
def vector_db_type(self):
return self.config.get("vector_db", {}).get("db_type", self.vector_db_defaults.get("DB_TYPE"))
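To make the getter/setter contract above concrete, a self-contained stand-in (plain dicts rather than the real args class) showing the two-step fallback the getters use:

```
# Stand-in for the config-backed properties; not the real args class.
config = {"embedding": {}}
embedding_defaults = {"USE_BEDROCK": False}

# Getter: config value first, then the module default.
use_bedrock = config.get("embedding", {}).get(
    "use_bedrock", embedding_defaults.get("USE_BEDROCK")
)
assert use_bedrock is False  # unset in config, so the default applies

# Setter: writes straight into config["embedding"].
config["embedding"]["use_bedrock"] = True
assert config.get("embedding", {}).get("use_bedrock") is True
```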
2 changes: 1 addition & 1 deletion src/agrag/configs/embedding/default.yaml
@@ -1,5 +1,5 @@
DEFAULT_EMBEDDING_MODEL: "BAAI/bge-large-en"
POOLING_STRATEGY: None
POOLING_STRATEGY: null
NORMALIZE_EMBEDDINGS: False
HF_TOKENIZER_PARAMS: {"truncation": True, "padding": True}
EMBEDDING_BATCH_SIZE: 128
5 changes: 3 additions & 2 deletions src/agrag/configs/presets/medium_quality_config.yaml
@@ -9,8 +9,9 @@ data:
- "txt"

embedding:
embedding_model: BAAI/bge-large-en
pooling_strategy: cls
embedding_model: cohere.embed-english-v3
use_bedrock: true
bedrock_aws_region: us-west-2
normalize_embeddings: false
hf_tokenizer_params:
truncation: true
9 changes: 8 additions & 1 deletion src/agrag/modules/embedding/README.md
@@ -6,7 +6,7 @@ Here are the configurable parameters for this module:

```
embedding:
embedding_model: The name of the Huggingface model to use for generating embeddings (default is "BAAI/bge-large-en").
embedding_model: The name of the Huggingface or Bedrock model to use for generating embeddings (default is "BAAI/bge-large-en" from Huggingface). Currently only Amazon Titan and Cohere Embedding models are supported on Bedrock.
pooling_strategy: The strategy to use for pooling embeddings. Options are 'mean', 'max', 'cls' (default is None).
@@ -23,4 +23,11 @@ embedding:
normalization_params: Additional parameters to pass to the PyTorch `nn.functional.normalize` method.
query_instruction_for_retrieval: Instruction for query when using embedding model.
use_bedrock: Whether to use the provided model through the AWS Bedrock API. https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html
Currently only Cohere (https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-embed.html) and Amazon Titan (https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan.html) embedding models are supported.
bedrock_embedding_params: Additional parameters to pass into the model when generating the embeddings.
bedrock_aws_region: AWS region where the model is hosted on Bedrock.
```
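For reference, a hedged Python-side equivalent of a Bedrock embedding config (keys taken from the parameter list above; the Titan model ID is an assumption to be checked against the Bedrock model catalog):

```
# Illustrative config mirroring the documented YAML keys; the model ID
# is an assumed Titan embedding model, not a value from this commit.
embedding_config = {
    "embedding_model": "amazon.titan-embed-text-v1",
    "use_bedrock": True,
    "bedrock_aws_region": "us-west-2",
    "bedrock_embedding_params": {},
    "normalize_embeddings": False,
}
```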
75 changes: 53 additions & 22 deletions src/agrag/modules/embedding/embedding.py
@@ -1,6 +1,8 @@
import json
import logging
from typing import List, Union

import boto3
import numpy as np
import pandas as pd
import torch
@@ -9,7 +9,7 @@
from transformers import AutoModel, AutoTokenizer

from agrag.constants import DOC_TEXT_KEY, EMBEDDING_HIDDEN_DIM_KEY, EMBEDDING_KEY
from agrag.modules.embedding.utils import normalize_embedding, pool
from agrag.modules.embedding.utils import get_embeddings_bedrock, normalize_embedding, pool

logger = logging.getLogger("rag-logger")

@@ -20,8 +20,8 @@ class EmbeddingModule:
Attributes:
----------
hf_model : str
The name of the Huggingface model to use for generating embeddings (default is "BAAI/bge-large-en").
model_name : str
The name of the Huggingface or Bedrock model to use for generating embeddings (default is "BAAI/bge-large-en" from Huggingface).
pooling_strategy : str
The strategy to use for pooling embeddings. Options are 'mean', 'max', 'cls' (default is None).
normalize_embeddings: bool
@@ -38,6 +38,13 @@ class EmbeddingModule:
Additional parameters to pass to the PyTorch `nn.functional.normalize` method.
query_instruction_for_retrieval: str
Instruction for query when using embedding model.
use_bedrock: bool
Whether to use the provided model through the AWS Bedrock API. https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html
Currently only Cohere (https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-embed.html) and Amazon Titan (https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan.html) embedding models are supported.
bedrock_embedding_params: dict
Additional parameters to pass into the model when generating the embeddings.
bedrock_aws_region: str
AWS region where the model is hosted on Bedrock.
Methods:
-------
@@ -50,12 +59,12 @@ class EmbeddingModule:

def __init__(
self,
hf_model: str = "BAAI/bge-large-en",
model_name: str = "BAAI/bge-large-en",
pooling_strategy: str = None,
normalize_embeddings: bool = False,
**kwargs,
):
self.hf_model = hf_model
self.model_name = model_name
self.pooling_strategy = pooling_strategy
self.normalize_embeddings = normalize_embeddings
self.hf_model_params = kwargs.get("hf_model_params", {})
@@ -66,14 +75,25 @@ def __init__(
self.query_instruction_for_retrieval = kwargs.get("query_instruction_for_retrieval", None)
self.num_gpus = kwargs.get("num_gpus", 0)
self.device = "cpu" if not self.num_gpus else torch.device("cuda" if torch.cuda.is_available() else "cpu")

logger.info(f"Using Huggingface Model {self.hf_model} for Embedding Module")
self.tokenizer = AutoTokenizer.from_pretrained(self.hf_model, **self.hf_tokenizer_init_params)
self.model = AutoModel.from_pretrained(self.hf_model, **self.hf_model_params)
if self.num_gpus > 1:
logger.info(f"Using {self.num_gpus} GPUs")
self.model = DataParallel(self.model)
self.model.to(self.device)
self.use_bedrock = kwargs.get("use_bedrock")
if self.use_bedrock:
if not "embed" in self.model_name:
raise ValueError(
f"Invalid model_id {self.model_name}. Must use an embedding model from Bedrock. The model_id should contain 'embed'."
)
logger.info(f"Using Bedrock Model {self.model_name} for Embedding Module")
self.bedrock_embedding_params = kwargs.get("bedrock_embedding_params", {})
if "cohere" in self.model_name:
self.bedrock_embedding_params["input_type"] = "search_document"
self.client = boto3.client("bedrock-runtime", region_name=kwargs.get("bedrock_aws_region", None))
else:
logger.info(f"Using Huggingface Model {self.model_name} for Embedding Module")
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, **self.hf_tokenizer_init_params)
self.model = AutoModel.from_pretrained(self.model_name, **self.hf_model_params)
if self.num_gpus > 1:
logger.info(f"Using {self.num_gpus} GPUs")
self.model = DataParallel(self.model)
self.model.to(self.device)

def encode(self, data: pd.DataFrame, pbar: tqdm = None, batch_size: int = 32) -> pd.DataFrame:
"""
@@ -112,24 +132,35 @@ def encode(self, data: pd.DataFrame, pbar: tqdm = None, batch_size: int = 32) -> pd.DataFrame:

logger.info("\nTokenizing text chunks")
batch_texts = texts[i : i + batch_size]
inputs = self.tokenizer(batch_texts, return_tensors="pt", **self.hf_tokenizer_params)
inputs = {k: v.to(self.device) for k, v in inputs.items()}

logger.info("\nGenerating embeddings")
with torch.no_grad():
outputs = self.model(**inputs, **self.hf_forward_params)
if self.use_bedrock:
batch_embeddings = get_embeddings_bedrock(
batch_texts=batch_texts,
client=self.client,
model_id=self.model_name,
embedding_params=self.bedrock_embedding_params,
)
else:
inputs = self.tokenizer(batch_texts, return_tensors="pt", **self.hf_tokenizer_params)
inputs = {k: v.to(self.device) for k, v in inputs.items()}
with torch.no_grad():
outputs = self.model(**inputs, **self.hf_forward_params)

logger.info("\nProcessing embeddings")
# The first element in the tuple returned by the model is the embeddings generated
# The tuple elements are (embeddings, hidden_states, past_key_values, attentions, cross_attentions)
batch_embeddings = outputs[0]

# The first element in the tuple returned by the model is the embeddings generated
# The tuple elements are (embeddings, hidden_states, past_key_values, attentions, cross_attentions)
batch_embeddings = outputs[0]
logger.info("\nProcessing embeddings")

batch_embeddings = pool(batch_embeddings, self.pooling_strategy)
if self.normalize_embeddings:
batch_embeddings = normalize_embedding(batch_embeddings, **self.normalization_params)

batch_embeddings = batch_embeddings.cpu().numpy()
if isinstance(batch_embeddings, torch.Tensor):
batch_embeddings = batch_embeddings.cpu().numpy()
else:
batch_embeddings = np.array(batch_embeddings)
all_embeddings.extend(batch_embeddings)
all_embeddings_hidden_dim.extend([batch_embeddings.shape[-1]] * batch_embeddings.shape[0])

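The isinstance branch at the end of encode exists because the two backends return different types; a small self-contained illustration:

```
# HF models yield torch.Tensor batches; Bedrock returns plain Python
# lists. encode() coerces both to numpy arrays before storage.
import numpy as np
import torch

for batch in (torch.zeros(2, 1024), [[0.0] * 1024] * 2):
    if isinstance(batch, torch.Tensor):
        batch = batch.cpu().numpy()
    else:
        batch = np.array(batch)
    assert batch.shape == (2, 1024)
```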
74 changes: 73 additions & 1 deletion src/agrag/modules/embedding/utils.py
@@ -1,8 +1,13 @@
from typing import List
import logging
from typing import Dict, List

import boto3
import torch
from torch.nn import functional as F

import json

logger = logging.getLogger("rag-logger")


def pool(embeddings: List[torch.Tensor], pooling_strategy: str) -> List[torch.Tensor]:
"""
@@ -73,3 +78,70 @@ def normalize_embedding(embeddings, args=None):
normalized_embeddings = normalize(embeddings, args)
"""
return F.normalize(embeddings, **args)


def get_embeddings_bedrock(
    batch_texts: List[str], client: boto3.client, model_id: str, embedding_params: dict = None
) -> List[List[float]]:
    """
    Generates embeddings for a batch of texts using an Amazon Titan or Cohere model hosted on Bedrock.
    """
    embedding_params = embedding_params or {}
    embeddings = []
if "titan" in model_id:
for text in batch_texts:
body = json.dumps(
{
"inputText": text,
**embedding_params,
}
)
response = client.invoke_model(
body=body,
modelId=model_id,
accept="application/json",
contentType="application/json",
)
outputs = json.loads(response["body"].read())
embeddings.append(outputs.get("embedding"))
elif "cohere" in model_id:
body = json.dumps(
{
"texts": batch_texts,
**embedding_params,
}
)
response = client.invoke_model(
body=body,
modelId=model_id,
accept="application/json",
contentType="application/json",
)
outputs = json.loads(response["body"].read())
embeddings = outputs.get("embeddings")
else:
raise NotImplementedError(f"Unsupported embedding model for Bedrock: {model_id}")
return embeddings
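
A hedged usage sketch for get_embeddings_bedrock (requires AWS credentials with Bedrock model access; the model ID is one of the Cohere IDs referenced in the README):

```
# Usage sketch only: needs valid AWS credentials and Bedrock access.
import boto3

client = boto3.client("bedrock-runtime", region_name="us-west-2")
vectors = get_embeddings_bedrock(
    batch_texts=["first chunk", "second chunk"],
    client=client,
    model_id="cohere.embed-english-v3",
    embedding_params={"input_type": "search_document"},
)
assert len(vectors) == 2  # one embedding (list of floats) per text
```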


def extract_response(output: Dict) -> List[float]:
    """
    Extracts the response embedding from the model output.

    Parameters:
    ----------
    output : Dict
        The output dictionary from the Bedrock model.

    Returns:
    -------
    List[float]
        The extracted response embedding.
    """
# Used for Mistral response
if "outputs" in output and isinstance(output["outputs"], list) and "embedding" in output["outputs"][0]:
return output["outputs"][0]["embedding"]
# Used for Anthropic response
elif "content" in output and output["type"] == "message":
return output["content"][0]["embedding"]
# Used for Llama response
elif "generation" in output:
return output["generation"]
else:
raise ValueError("Unknown output structure: %s", output)
3 changes: 1 addition & 2 deletions src/agrag/modules/generator/generators/bedrock_generator.py
@@ -91,5 +91,4 @@ def extract_response(self, output: Dict) -> str:
elif "generation" in output:
return output["generation"].strip()
else:
logger.error("Unknown output structure: %s", output)
return ""
raise ValueError("Unknown output structure: %s", output)
2 changes: 2 additions & 0 deletions src/agrag/modules/retriever/retrievers/retriever_base.py
@@ -66,6 +66,8 @@ def encode_query(self, query: str) -> np.ndarray:
np.ndarray
The embedding of the query.
"""
if self.embedding_module.use_bedrock and "cohere" in self.embedding_module.model_name:
self.embedding_module.bedrock_embedding_params["input_type"] = "search_query"
query_embedding = self.embedding_module.encode(data=pd.DataFrame([{DOC_TEXT_KEY: query}]))
query_embedding = query_embedding[EMBEDDING_KEY][0]
return query_embedding
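The flip above matters because Cohere's Bedrock models embed documents and queries asymmetrically; a compact illustration of the two request bodies (field names from the Cohere branch of get_embeddings_bedrock):

```
# Same model, different input_type: "search_document" while indexing
# (set in EmbeddingModule.__init__), "search_query" at retrieval time.
import json

index_body = json.dumps({"texts": ["a chunk"], "input_type": "search_document"})
query_body = json.dumps({"texts": ["a question"], "input_type": "search_query"})
```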
6 changes: 3 additions & 3 deletions tests/unittests/embedding/test_embedding.py
@@ -19,15 +19,15 @@ def setUp(self, mock_model, mock_tokenizer):
mock_tokenizer.return_value = self.mock_tokenizer
mock_model.return_value = self.mock_model

hf_model_params = {"param": "param"}
model_name_params = {"param": "param"}
hf_tokenizer_params = {"param": True}
tokenizer_params = {"padding": 10, "max_length": 512}
forward_params = {"param": True}

self.embedding_module = EmbeddingModule(
hf_model="some-model",
model_name="some-model",
pooling_strategy=None,
hf_model_params=hf_model_params,
model_name_params=model_name_params,
hf_tokenizer_init_params=hf_tokenizer_params,
hf_tokenizer_params=tokenizer_params,
hf_forward_params=forward_params,
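The updated tests still cover only the Huggingface path; a hedged sketch of what a Bedrock-path test might look like (import path assumed, boto3 patched so no AWS calls are made):

```
# Hypothetical test sketch, not part of the commit.
from unittest.mock import patch

from agrag.modules.embedding.embedding import EmbeddingModule

with patch("boto3.client") as mock_client:
    module = EmbeddingModule(
        model_name="cohere.embed-english-v3",
        use_bedrock=True,
        bedrock_aws_region="us-west-2",
    )
mock_client.assert_called_once_with("bedrock-runtime", region_name="us-west-2")
assert module.bedrock_embedding_params["input_type"] == "search_document"
```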
