Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bedrock client #213

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,4 @@ index.faiss
*.svg
# ignore the softlink to adalflow cache
*.adalflow
.idea
3 changes: 2 additions & 1 deletion adalflow/adalflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from adalflow.optim.grad_component import GradComponent
from adalflow.core.generator import Generator


from adalflow.core.types import (
GeneratorOutput,
EmbedderOutput,
Expand Down Expand Up @@ -55,6 +54,7 @@
TransformersClient,
AnthropicAPIClient,
CohereAPIClient,
BedrockAPIClient,
)

__all__ = [
Expand Down Expand Up @@ -111,4 +111,5 @@
"TransformersClient",
"AnthropicAPIClient",
"CohereAPIClient",
"BedrockAPIClient",
]
6 changes: 5 additions & 1 deletion adalflow/adalflow/components/model_client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@
"adalflow.components.model_client.anthropic_client.AnthropicAPIClient",
OptionalPackages.ANTHROPIC,
)
BedrockAPIClient = LazyImport(
"adalflow.components.model_client.bedrock_client.BedrockAPIClient",
OptionalPackages.BEDROCK,
)
GroqAPIClient = LazyImport(
"adalflow.components.model_client.groq_client.GroqAPIClient",
OptionalPackages.GROQ,
Expand Down Expand Up @@ -61,14 +65,14 @@
OptionalPackages.OPENAI,
)


__all__ = [
"CohereAPIClient",
"TransformerReranker",
"TransformerEmbedder",
"TransformerLLM",
"TransformersClient",
"AnthropicAPIClient",
"BedrockAPIClient",
"GroqAPIClient",
"OpenAIClient",
"GoogleGenAIClient",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,4 +167,4 @@ async def acall(
elif model_type == ModelType.LLM:
return await self.async_client.messages.create(**api_kwargs)
else:
raise ValueError(f"model_type {model_type} is not supported")
raise ValueError(f"model_type {model_type} is not supported")
152 changes: 152 additions & 0 deletions adalflow/adalflow/components/model_client/bedrock_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
"""AWS Bedrock ModelClient integration."""

import os
from typing import Dict, Optional, Any, Callable
import backoff
import logging

from adalflow.core.model_client import ModelClient
from adalflow.core.types import ModelType, CompletionUsage, GeneratorOutput

import boto3
from botocore.config import Config

log = logging.getLogger(__name__)


def get_first_message_content(completion: Dict) -> str:
    r"""Extract the text of the first content block of the response message.

    This is the default parser applied to a Bedrock converse-API completion.
    """
    message = completion["output"]["message"]
    first_block = message["content"][0]
    return first_block["text"]


__all__ = ["BedrockAPIClient", "get_first_message_content"]

# Exception classes referenced by the @backoff decorator on BedrockAPIClient.call.
# NOTE(review): this builds a "bedrock-runtime" client at module import time,
# which requires AWS region/credentials to be resolvable on import — confirm
# this is acceptable for an optional, lazily-imported model client.
bedrock_runtime_exceptions = boto3.client("bedrock-runtime").exceptions


class BedrockAPIClient(ModelClient):
    __doc__ = r"""A component wrapper for the AWS Bedrock runtime API client.

    Credentials are resolved from the constructor arguments first, falling back
    to the standard AWS environment variables. If the compute environment
    assumes an IAM role with Bedrock access, nothing needs to be passed.

    Visit https://docs.aws.amazon.com/bedrock/latest/APIReference/welcome.html
    for more API details.
    """

    def __init__(
        self,
        aws_profile_name=None,
        aws_region_name=None,
        aws_access_key_id=None,
        aws_secret_access_key=None,
        aws_session_token=None,
        aws_connection_timeout=None,
        aws_read_timeout=None,
        chat_completion_parser: Callable = None,
    ):
        """Initialize the client and eagerly create the sync boto3 client.

        Args:
            aws_profile_name: Named AWS profile; falls back to ``AWS_PROFILE_NAME``.
            aws_region_name: AWS region; falls back to ``AWS_REGION_NAME``.
            aws_access_key_id: Access key; falls back to ``AWS_ACCESS_KEY_ID``.
            aws_secret_access_key: Secret key; falls back to ``AWS_SECRET_ACCESS_KEY``.
            aws_session_token: Session token; falls back to ``AWS_SESSION_TOKEN``.
            aws_connection_timeout: Connection timeout in seconds (optional).
            aws_read_timeout: Read timeout in seconds (optional).
            chat_completion_parser: Callable applied to raw completions; defaults
                to :func:`get_first_message_content`.
        """
        super().__init__()
        self._aws_profile_name = aws_profile_name
        self._aws_region_name = aws_region_name
        self._aws_access_key_id = aws_access_key_id
        self._aws_secret_access_key = aws_secret_access_key
        self._aws_session_token = aws_session_token
        self._aws_connection_timeout = aws_connection_timeout
        self._aws_read_timeout = aws_read_timeout

        self.session = None
        self.sync_client = self.init_sync_client()
        self.chat_completion_parser = (
            chat_completion_parser or get_first_message_content
        )

    def init_sync_client(self):
        """Create and return the sync "bedrock-runtime" boto3 client.

        There is no need to pass both a profile and access/secret keys — pass
        one of them. If the compute environment assumes a role that has access
        to Bedrock, nothing needs to be passed at all.
        """
        aws_profile_name = self._aws_profile_name or os.getenv("AWS_PROFILE_NAME")
        aws_region_name = self._aws_region_name or os.getenv("AWS_REGION_NAME")
        aws_access_key_id = self._aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID")
        aws_secret_access_key = self._aws_secret_access_key or os.getenv(
            "AWS_SECRET_ACCESS_KEY"
        )
        aws_session_token = self._aws_session_token or os.getenv("AWS_SESSION_TOKEN")

        # Only build a Config when at least one timeout is requested, so the
        # boto3 defaults apply otherwise.
        config = None
        if self._aws_connection_timeout or self._aws_read_timeout:
            config = Config(
                connect_timeout=self._aws_connection_timeout,  # seconds
                read_timeout=self._aws_read_timeout,  # seconds
            )

        session = boto3.Session(
            profile_name=aws_profile_name,
            region_name=aws_region_name,
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            aws_session_token=aws_session_token,
        )
        bedrock_runtime = session.client(service_name="bedrock-runtime", config=config)
        return bedrock_runtime

    def init_async_client(self):
        raise NotImplementedError("Async call not implemented yet.")

    def parse_chat_completion(self, completion):
        """Parse a converse-API completion into a ``GeneratorOutput``.

        Uses ``self.chat_completion_parser`` (fix: previously the parser set in
        ``__init__`` was ignored and the extraction was hardcoded), so a custom
        parser passed at construction time is honored. On failure the error and
        the stringified completion are returned instead of raising.
        """
        log.debug(f"completion: {completion}")
        try:
            data = self.chat_completion_parser(completion)
            usage = self.track_completion_usage(completion)
            # Convention: raw_response carries the parsed text; data is filled
            # later by downstream output processors.
            return GeneratorOutput(data=None, usage=usage, raw_response=data)
        except Exception as e:
            log.error(f"Error parsing completion: {e}")
            return GeneratorOutput(
                data=None, error=str(e), raw_response=str(completion)
            )

    def track_completion_usage(self, completion: Dict) -> CompletionUsage:
        r"""Map Bedrock's ``usage`` payload onto ``CompletionUsage``."""
        usage = completion["usage"]
        return CompletionUsage(
            completion_tokens=usage["outputTokens"],
            prompt_tokens=usage["inputTokens"],
            total_tokens=usage["totalTokens"],
        )

    def convert_inputs_to_api_kwargs(
        self,
        input: Optional[Any] = None,
        model_kwargs: Optional[Dict] = None,
        model_type: ModelType = ModelType.UNDEFINED,
    ) -> Dict:
        """Combine ``input`` and ``model_kwargs`` into converse-API kwargs.

        Check the converse API doc here:
        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock-runtime/client/converse.html

        Raises:
            ValueError: If ``model_type`` is not ``ModelType.LLM``.
        """
        # Copy to avoid mutating the caller's dict; also fixes the former
        # mutable default argument.
        api_kwargs = dict(model_kwargs or {})
        if model_type == ModelType.LLM:
            api_kwargs["messages"] = [
                {"role": "user", "content": [{"text": input}]},
            ]
        else:
            raise ValueError(f"Model type {model_type} not supported")
        return api_kwargs

    @backoff.on_exception(
        backoff.expo,
        (
            bedrock_runtime_exceptions.ThrottlingException,
            bedrock_runtime_exceptions.ModelTimeoutException,
            bedrock_runtime_exceptions.InternalServerException,
            bedrock_runtime_exceptions.ModelErrorException,
            bedrock_runtime_exceptions.ValidationException,
        ),
        max_time=5,
    )
    def call(
        self,
        api_kwargs: Optional[Dict] = None,
        model_type: ModelType = ModelType.UNDEFINED,
    ):
        """Invoke the Bedrock converse API with the combined kwargs.

        ``api_kwargs`` is the combined input and model_kwargs produced by
        :meth:`convert_inputs_to_api_kwargs`. Retries with exponential backoff
        on transient Bedrock runtime exceptions for up to 5 seconds.

        Raises:
            ValueError: If ``model_type`` is not ``ModelType.LLM``.
        """
        api_kwargs = api_kwargs or {}
        if model_type == ModelType.LLM:
            return self.sync_client.converse(**api_kwargs)
        else:
            raise ValueError(f"model_type {model_type} is not supported")

    async def acall(
        self,
        api_kwargs: Optional[Dict] = None,
        model_type: ModelType = ModelType.UNDEFINED,
    ):
        # Accept the same arguments as call() (fix: the old zero-arg signature
        # raised TypeError instead of NotImplementedError when invoked by the
        # Generator with positional arguments).
        raise NotImplementedError("Async call not implemented yet.")
4 changes: 2 additions & 2 deletions adalflow/adalflow/utils/lazy_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import logging
from types import ModuleType


from enum import Enum

log = logging.getLogger(__name__)
Expand All @@ -19,6 +18,7 @@ class OptionalPackages(Enum):
GROQ = ("groq", "Please install groq with: pip install groq")
OPENAI = ("openai", "Please install openai with: pip install openai")
ANTHROPIC = ("anthropic", "Please install anthropic with: pip install anthropic")
BEDROCK = ("bedrock", "Please install boto3 with: pip install boto3")
GOOGLE_GENERATIVEAI = (
"google.generativeai",
"Please install google-generativeai with: pip install google-generativeai",
Expand Down Expand Up @@ -78,7 +78,7 @@ class LazyImport:
"""

def __init__(
self, import_path: str, optional_package: OptionalPackages, *args, **kwargs
self, import_path: str, optional_package: OptionalPackages, *args, **kwargs
):
if args or kwargs:
raise TypeError(
Expand Down
Loading