Commit

fix: linting fixes
krrishdholakia committed Jul 4, 2024
1 parent 17869fc commit 19c982d
Showing 5 changed files with 62 additions and 16 deletions.
21 changes: 15 additions & 6 deletions litellm/litellm_core_utils/litellm_logging.py
@@ -426,13 +426,22 @@ def post_call(
self.model_call_details["additional_args"] = additional_args
self.model_call_details["log_event_type"] = "post_api_call"

verbose_logger.debug(
"RAW RESPONSE:\n{}\n\n".format(
self.model_call_details.get(
"original_response", self.model_call_details
if json_logs:
verbose_logger.debug(
"RAW RESPONSE:\n{}\n\n".format(
self.model_call_details.get(
"original_response", self.model_call_details
)
),
)
else:
print_verbose(
"RAW RESPONSE:\n{}\n\n".format(
self.model_call_details.get(
"original_response", self.model_call_details
)
)
),
)
)
if self.logger_fn and callable(self.logger_fn):
try:
self.logger_fn(
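Note: with this change, post_call routes the raw provider response through verbose_logger.debug only when JSON logging is enabled, and falls back to print_verbose otherwise, keeping structured and plain-text log output separate. A minimal sketch of the toggle, assuming json_logs mirrors litellm's module-level flag and print_verbose simply writes to stdout:

    import logging

    verbose_logger = logging.getLogger("LiteLLM")
    json_logs = True  # assumed to mirror litellm.json_logs / the JSON_LOGS env var

    def print_verbose(msg: str) -> None:
        # plain stdout path used when structured JSON logging is disabled
        print(msg)

    def log_raw_response(model_call_details: dict) -> None:
        raw = model_call_details.get("original_response", model_call_details)
        if json_logs:
            verbose_logger.debug("RAW RESPONSE:\n{}\n\n".format(raw))
        else:
            print_verbose("RAW RESPONSE:\n{}\n\n".format(raw))
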
5 changes: 4 additions & 1 deletion litellm/llms/anthropic.py
@@ -601,13 +601,16 @@ def completion(
optional_params["tools"] = anthropic_tools

stream = optional_params.pop("stream", None)
is_vertex_request: bool = optional_params.pop("is_vertex_request", False)

data = {
"model": model,
"messages": messages,
**optional_params,
}

if is_vertex_request is False:
data["model"] = model

## LOGGING
logging_obj.pre_call(
input=messages,
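Note: the new is_vertex_request flag keeps the model name out of the request body for Vertex AI calls, since Vertex encodes the model in the endpoint URL; direct Anthropic API calls still send it as a top-level "model" field. A simplified sketch of the body assembly (illustrative, not the full handler):

    def build_anthropic_request(model: str, messages: list, optional_params: dict) -> dict:
        # Vertex AI routes by URL (.../publishers/anthropic/models/{model}:rawPredict),
        # so the body must omit "model"; the direct Anthropic API expects it.
        is_vertex_request: bool = optional_params.pop("is_vertex_request", False)
        data = {"messages": messages, **optional_params}
        if is_vertex_request is False:
            data["model"] = model
        return data
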
40 changes: 35 additions & 5 deletions litellm/llms/vertex_ai_anthropic.py
@@ -15,6 +15,7 @@
import litellm
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.types.llms.anthropic import AnthropicMessagesToolChoice
from litellm.types.utils import ResponseFormatChunk
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage

@@ -121,6 +122,17 @@ def map_openai_params(self, non_default_params: dict, optional_params: dict):
optional_params["max_tokens"] = value
if param == "tools":
optional_params["tools"] = value
if param == "tool_choice":
_tool_choice: Optional[AnthropicMessagesToolChoice] = None
if value == "auto":
_tool_choice = {"type": "auto"}
elif value == "required":
_tool_choice = {"type": "any"}
elif isinstance(value, dict):
_tool_choice = {"type": "tool", "name": value["function"]["name"]}

if _tool_choice is not None:
optional_params["tool_choice"] = _tool_choice
if param == "stream":
optional_params["stream"] = value
if param == "stop":
@@ -177,17 +189,29 @@ def get_vertex_client(
_credentials, cred_project_id = VertexLLM().load_auth(
credentials=vertex_credentials, project_id=vertex_project
)

vertex_ai_client = AnthropicVertex(
project_id=vertex_project or cred_project_id,
region=vertex_location or "us-central1",
access_token=_credentials.token,
)
access_token = _credentials.token
else:
vertex_ai_client = client
access_token = client.access_token

return vertex_ai_client, access_token


def create_vertex_anthropic_url(
vertex_location: str, vertex_project: str, model: str, stream: bool
) -> str:
if stream is True:
return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/anthropic/models/{model}:streamRawPredict"
else:
return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/anthropic/models/{model}:rawPredict"


def completion(
model: str,
messages: list,
@@ -196,6 +220,8 @@ def completion(
encoding,
logging_obj,
optional_params: dict,
custom_prompt_dict: dict,
headers: Optional[dict],
vertex_project=None,
vertex_location=None,
vertex_credentials=None,
@@ -207,6 +233,9 @@
try:
import vertexai
from anthropic import AnthropicVertex

from litellm.llms.anthropic import AnthropicChatCompletion
from litellm.llms.vertex_httpx import VertexLLM
except:
raise VertexAIError(
status_code=400,
@@ -222,13 +251,14 @@
)
try:

vertex_ai_client, access_token = get_vertex_client(
client=client,
vertex_project=vertex_project,
vertex_location=vertex_location,
vertex_credentials=vertex_credentials,
vertex_httpx_logic = VertexLLM()

access_token, project_id = vertex_httpx_logic._ensure_access_token(
credentials=vertex_credentials, project_id=vertex_project
)

anthropic_chat_completions = AnthropicChatCompletion()

## Load Config
config = litellm.VertexAIAnthropicConfig.get_config()
for k, v in config.items():
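Note: in this file, completion() now resolves credentials via VertexLLM()._ensure_access_token and instantiates the shared AnthropicChatCompletion handler in place of the AnthropicVertex SDK client, with create_vertex_anthropic_url building the rawPredict / streamRawPredict endpoint. Hypothetical usage of the URL builder (values are illustrative):

    url = create_vertex_anthropic_url(
        vertex_location="us-central1",
        vertex_project="my-gcp-project",
        model="claude-3-sonnet@20240229",
        stream=False,
    )
    # -> https://us-central1-aiplatform.googleapis.com/v1/projects/my-gcp-project/
    #    locations/us-central1/publishers/anthropic/models/claude-3-sonnet@20240229:rawPredict
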
2 changes: 2 additions & 0 deletions litellm/main.py
@@ -2008,6 +2008,8 @@ def completion(
vertex_credentials=vertex_credentials,
logging_obj=logging,
acompletion=acompletion,
headers=headers,
custom_prompt_dict=custom_prompt_dict,
)
else:
model_response = vertex_ai.completion(
10 changes: 6 additions & 4 deletions litellm/tests/test_amazing_vertex_completion.py
@@ -640,11 +640,13 @@ def test_gemini_pro_vision_base64():
pytest.fail(f"An exception occurred - {str(e)}")


@pytest.mark.skip(reason="exhausted vertex quota. need to refactor to mock the call")
@pytest.mark.parametrize("provider", ["vertex_ai_beta"]) # "vertex_ai",
# @pytest.mark.skip(reason="exhausted vertex quota. need to refactor to mock the call")
@pytest.mark.parametrize(
"model", ["vertex_ai_beta/gemini-1.5-pro", "vertex_ai/claude-3-sonnet@20240229"]
) # "vertex_ai",
@pytest.mark.parametrize("sync_mode", [True]) # "vertex_ai",
@pytest.mark.asyncio
async def test_gemini_pro_function_calling_httpx(provider, sync_mode):
async def test_gemini_pro_function_calling_httpx(model, sync_mode):
try:
load_vertex_ai_credentials()
litellm.set_verbose = True
@@ -682,7 +684,7 @@ async def test_gemini_pro_function_calling_httpx(provider, sync_mode):
]

data = {
"model": "{}/gemini-1.5-pro".format(provider),
"model": model,
"messages": messages,
"tools": tools,
"tool_choice": "required",
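Note: the test is now parameterized over full model strings rather than a provider prefix, so the same function-calling test exercises both vertex_ai_beta/gemini-1.5-pro and vertex_ai/claude-3-sonnet@20240229, and the quota-skip marker appears to be commented out so the test runs again. A minimal illustration of how stacked parametrize decorators expand (not litellm code):

    import pytest

    @pytest.mark.parametrize(
        "model", ["vertex_ai_beta/gemini-1.5-pro", "vertex_ai/claude-3-sonnet@20240229"]
    )
    @pytest.mark.parametrize("sync_mode", [True])
    def test_model_matrix(model, sync_mode):
        # pytest generates one case per (model, sync_mode) pair: 2 x 1 = 2 cases
        assert isinstance(model, str) and sync_mode is True
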
