LiteLLM Minor Fixes and Improvements (09/07/2024) (#5580)
* fix(litellm_logging.py): set completion_start_time_float to end_time_float if none

Fixes #5500

* feat(__init__.py): add new 'openai_text_completion_compatible_providers' list

Fixes #5558

Correctly routes Fireworks AI calls made via text completions

* fix: fix linting errors

* fix: fix linting errors

* fix(openai.py): fix exception raised

* fix(openai.py): fix error handling

* fix(_redis.py): allow all supported arguments for redis cluster (#5554)

* Revert "fix(_redis.py): allow all supported arguments for redis cluster (#5554)" (#5583)

This reverts commit f2191ef.

* fix(router.py): return model alias w/ underlying deployment on router.get_model_list()

Fixes #5524 (comment)

* test: handle flaky tests

---------

Co-authored-by: Jonas Dittrich <58814480+Kakadus@users.noreply.github.com>
krrishdholakia and Kakadus authored Sep 10, 2024
1 parent c86b333 commit 4ac66bd
Showing 14 changed files with 101 additions and 34 deletions.
7 changes: 6 additions & 1 deletion litellm/__init__.py
@@ -483,7 +483,12 @@ def identify(event_details):
"azure_ai",
"github",
]

openai_text_completion_compatible_providers: List = (
[ # providers that support `/v1/completions`
"together_ai",
"fireworks_ai",
]
)

# well supported replicate llms
replicate_models: List = [
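The new list is consulted during routing in `litellm/main.py` (see that file's diff below). A minimal sketch of the membership check, assuming the module-level attribute added above (the helper name is illustrative, not part of the diff):

```python
import litellm

# Providers in this list can be served through the OpenAI-compatible
# /v1/completions handler when a text-completion call is made.
def supports_openai_text_completion(custom_llm_provider: str) -> bool:
    return custom_llm_provider in litellm.openai_text_completion_compatible_providers

assert supports_openai_text_completion("fireworks_ai")
assert not supports_openai_text_completion("anthropic")
```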
2 changes: 2 additions & 0 deletions litellm/litellm_core_utils/litellm_logging.py
@@ -2329,6 +2329,8 @@ def get_standard_logging_object_payload(
completion_start_time_float = completion_start_time.timestamp()
elif isinstance(completion_start_time, float):
completion_start_time_float = completion_start_time
else:
completion_start_time_float = end_time_float
# clean up litellm hidden params
clean_hidden_params = StandardLoggingHiddenParams(
model_id=None,
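Restated outside the diff context, the added branch guarantees `completion_start_time_float` is always a float, which is the None case behind #5500. A self-contained sketch (the helper name is illustrative, not part of the diff):

```python
from datetime import datetime
from typing import Optional, Union

def resolve_completion_start_time(
    completion_start_time: Optional[Union[datetime, float]],
    end_time_float: float,
) -> float:
    # Mirror of the fallback added above: when no completion start time was
    # recorded, fall back to the end time so the logging payload never
    # carries a None timestamp.
    if isinstance(completion_start_time, datetime):
        return completion_start_time.timestamp()
    elif isinstance(completion_start_time, float):
        return completion_start_time
    else:
        return end_time_float
```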
4 changes: 2 additions & 2 deletions litellm/llms/OpenAI/openai.py
@@ -1263,6 +1263,7 @@ async def async_streaming(

error_headers = getattr(e, "headers", None)
if response is not None and hasattr(response, "text"):
error_headers = getattr(e, "headers", None)
raise OpenAIError(
status_code=500,
message=f"{str(e)}\n\nOriginal Response: {response.text}",
@@ -1800,12 +1801,11 @@ def completion(
headers: Optional[dict] = None,
):
super().completion()
exception_mapping_worked = False
try:
if headers is None:
headers = self.validate_environment(api_key=api_key)
if model is None or messages is None:
raise OpenAIError(status_code=422, message=f"Missing model or messages")
raise OpenAIError(status_code=422, message="Missing model or messages")

if (
len(messages) > 0
14 changes: 9 additions & 5 deletions litellm/llms/azure_text.py
@@ -162,11 +162,10 @@ def completion(
client=None,
):
super().completion()
exception_mapping_worked = False
try:
if model is None or messages is None:
raise AzureOpenAIError(
status_code=422, message=f"Missing model or messages"
status_code=422, message="Missing model or messages"
)

max_retries = optional_params.pop("max_retries", 2)
@@ -293,7 +292,10 @@ def completion(
"api-version", api_version
)

response = azure_client.completions.create(**data, timeout=timeout) # type: ignore
raw_response = azure_client.completions.with_raw_response.create(
**data, timeout=timeout
)
response = raw_response.parse()
stringified_response = response.model_dump()
## LOGGING
logging_obj.post_call(
@@ -380,13 +382,15 @@ async def acompletion(
"complete_input_dict": data,
},
)
response = await azure_client.completions.create(**data, timeout=timeout)
raw_response = await azure_client.completions.with_raw_response.create(
**data, timeout=timeout
)
response = raw_response.parse()
return openai_text_completion_config.convert_to_chat_model_response_object(
response_object=response.model_dump(),
model_response_object=model_response,
)
except AzureOpenAIError as e:
exception_mapping_worked = True
raise e
except Exception as e:
status_code = getattr(e, "status_code", 500)
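Both the sync and async Azure text-completion paths now go through the OpenAI SDK's `with_raw_response` wrapper, which exposes the provider's response headers while `.parse()` still returns the typed completion object. A rough usage sketch of that pattern (client configuration values are placeholders):

```python
from openai import AzureOpenAI

azure_client = AzureOpenAI(
    api_key="my-azure-key",                              # placeholder
    api_version="2024-02-01",                            # placeholder
    azure_endpoint="https://example.openai.azure.com",   # placeholder
)

raw_response = azure_client.completions.with_raw_response.create(
    model="my-deployment", prompt="hello", timeout=600
)
print(raw_response.headers.get("x-ratelimit-remaining-requests"))  # raw headers
response = raw_response.parse()  # same typed object completions.create() returns
print(response.model_dump())
```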
7 changes: 5 additions & 2 deletions litellm/main.py
@@ -1209,6 +1209,9 @@ def completion(
custom_llm_provider == "text-completion-openai"
or "ft:babbage-002" in model
or "ft:davinci-002" in model # support for finetuned completion models
or custom_llm_provider
in litellm.openai_text_completion_compatible_providers
and kwargs.get("text_completion") is True
):
openai.api_type = "openai"

@@ -4099,8 +4102,8 @@ def process_prompt(i, individual_prompt):

kwargs.pop("prompt", None)

if (
_model is not None and custom_llm_provider == "openai"
if _model is not None and (
custom_llm_provider == "openai"
): # for openai compatible endpoints - e.g. vllm, call the native /v1/completions endpoint for text completion calls
if _model not in litellm.open_ai_chat_completion_models:
model = "text-completion-openai/" + _model
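With the added branch, text-completion calls to providers listed in `openai_text_completion_compatible_providers` take the OpenAI `/v1/completions` path instead of the chat route. A usage sketch (mirrors the Fireworks AI test added further down; requires valid provider credentials):

```python
import litellm

# Routed through the OpenAI-compatible text-completion handler.
response = litellm.text_completion(
    model="fireworks_ai/llama-v3p1-8b-instruct",
    prompt="Say hello",
)
print(response.choices[0].text)
```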
19 changes: 6 additions & 13 deletions litellm/proxy/_new_secret_config.yaml
@@ -1,16 +1,9 @@
model_list:
- model_name: "anthropic/claude-3-5-sonnet-20240620"
- model_name: "gpt-turbo"
litellm_params:
model: anthropic/claude-3-5-sonnet-20240620
# api_base: http://0.0.0.0:9000
- model_name: gpt-3.5-turbo
litellm_params:
model: openai/*
model: azure/chatgpt-v-2
api_key: os.environ/AZURE_API_KEY
api_base: os.environ/AZURE_API_BASE

litellm_settings:
success_callback: ["s3"]
s3_callback_params:
s3_bucket_name: litellm-logs # AWS Bucket Name for S3
s3_region_name: us-west-2 # AWS Region Name for S3
s3_aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID # us os.environ/<variable name> to pass environment variables. This is AWS Access Key ID for S3
s3_aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY # AWS Secret Access Key for S3
router_settings:
model_group_alias: {"gpt-4": "gpt-turbo"}
24 changes: 23 additions & 1 deletion litellm/proxy/health_check.py
@@ -3,7 +3,7 @@
import asyncio
import logging
import random
from typing import Optional
from typing import List, Optional

import litellm
from litellm._logging import print_verbose
Expand Down Expand Up @@ -36,6 +36,25 @@ def _clean_endpoint_data(endpoint_data: dict, details: Optional[bool] = True):
)


def filter_deployments_by_id(
model_list: List,
) -> List:
seen_ids = set()
filtered_deployments = []

for deployment in model_list:
_model_info = deployment.get("model_info") or {}
_id = _model_info.get("id") or None
if _id is None:
continue

if _id not in seen_ids:
seen_ids.add(_id)
filtered_deployments.append(deployment)

return filtered_deployments


async def _perform_health_check(model_list: list, details: Optional[bool] = True):
"""
Perform a health check for each model in the list.
@@ -105,6 +124,9 @@ async def perform_health_check(
_new_model_list = [x for x in model_list if x["model_name"] == model]
model_list = _new_model_list

model_list = filter_deployments_by_id(
model_list=model_list
) # filter duplicate deployments (e.g. when model alias'es are used)
healthy_endpoints, unhealthy_endpoints = await _perform_health_check(
model_list, details
)
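`filter_deployments_by_id` keeps the first deployment seen for each `model_info.id` and drops entries without an id, so a model group alias pointing at an existing deployment is only health-checked once. An illustrative call (ids are made up):

```python
model_list = [
    {"model_name": "gpt-turbo", "model_info": {"id": "abc123"}},
    {"model_name": "gpt-4", "model_info": {"id": "abc123"}},  # alias of the same deployment
    {"model_name": "no-id-model", "model_info": {}},          # no id -> skipped
]

print(filter_deployments_by_id(model_list))
# [{'model_name': 'gpt-turbo', 'model_info': {'id': 'abc123'}}]
```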
4 changes: 2 additions & 2 deletions litellm/proxy/management_helpers/utils.py
@@ -109,8 +109,8 @@ async def add_new_member(
where={"user_id": user_info.user_id}, # type: ignore
data={"teams": {"push": [team_id]}},
)

returned_user = LiteLLM_UserTable(**_returned_user.model_dump())
if _returned_user is not None:
returned_user = LiteLLM_UserTable(**_returned_user.model_dump())
elif len(existing_user_row) > 1:
raise HTTPException(
status_code=400,
35 changes: 29 additions & 6 deletions litellm/router.py
@@ -4556,6 +4556,27 @@ def get_model_ids(self, model_name: Optional[str] = None) -> List[str]:
ids.append(id)
return ids

def _get_all_deployments(
self, model_name: str, model_alias: Optional[str] = None
) -> List[DeploymentTypedDict]:
"""
Return all deployments of a model name
Used for accurate 'get_model_list'.
"""

returned_models: List[DeploymentTypedDict] = []
for model in self.model_list:
if model["model_name"] == model_name:
if model_alias is not None:
alias_model = copy.deepcopy(model)
alias_model["model_name"] = model_alias
returned_models.append(alias_model)
else:
returned_models.append(model)

return returned_models

def get_model_names(self) -> List[str]:
"""
Returns all possible model names for router.
@@ -4567,24 +4588,26 @@ def get_model_names(self) -> List[str]:
def get_model_list(
self, model_name: Optional[str] = None
) -> Optional[List[DeploymentTypedDict]]:
"""
Includes router model_group_alias'es as well
"""
if hasattr(self, "model_list"):
returned_models: List[DeploymentTypedDict] = []

for model_alias, model_value in self.model_group_alias.items():
model_alias_item = DeploymentTypedDict(
model_name=model_alias,
litellm_params=LiteLLMParamsTypedDict(model=model_value),
returned_models.extend(
self._get_all_deployments(
model_name=model_value, model_alias=model_alias
)
)
returned_models.append(model_alias_item)

if model_name is None:
returned_models += self.model_list

return returned_models

for model in self.model_list:
if model["model_name"] == model_name:
returned_models.append(model)
returned_models.extend(self._get_all_deployments(model_name=model_name))

return returned_models
return None
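Taken together, `get_model_list()` now resolves a `model_group_alias` entry to its underlying deployment(s) rather than returning a bare alias stub. A sketch of the intended behavior (configuration values are placeholders):

```python
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-turbo",
            "litellm_params": {"model": "azure/chatgpt-v-2", "api_key": "my-key"},
        }
    ],
    model_group_alias={"gpt-4": "gpt-turbo"},
)

for deployment in router.get_model_list():
    print(deployment["model_name"], "->", deployment["litellm_params"]["model"])
# gpt-4 -> azure/chatgpt-v-2
# gpt-turbo -> azure/chatgpt-v-2
```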
2 changes: 2 additions & 0 deletions litellm/tests/test_completion.py
@@ -626,6 +626,8 @@ async def test_model_function_invoke(model, sync_mode, api_key, api_base):
response = await litellm.acompletion(**data)

print(f"response: {response}")
except litellm.InternalServerError:
pass
except litellm.RateLimitError as e:
pass
except Exception as e:
4 changes: 2 additions & 2 deletions litellm/tests/test_exceptions.py
@@ -864,7 +864,7 @@ def _pre_call_utils(
data["messages"] = [{"role": "user", "content": "Hello world"}]
if streaming is True:
data["stream"] = True
mapped_target = client.chat.completions.with_raw_response
mapped_target = client.chat.completions.with_raw_response # type: ignore
if sync_mode:
original_function = litellm.completion
else:
@@ -873,7 +873,7 @@ def _pre_call_utils(
data["prompt"] = "Hello world"
if streaming is True:
data["stream"] = True
mapped_target = client.completions.with_raw_response
mapped_target = client.completions.with_raw_response # type: ignore
if sync_mode:
original_function = litellm.text_completion
else:
1 change: 1 addition & 0 deletions litellm/tests/test_function_calling.py
@@ -52,6 +52,7 @@ def get_current_weather(location, unit="fahrenheit"):
# "anthropic.claude-3-sonnet-20240229-v1:0",
],
)
@pytest.mark.flaky(retries=3, delay=1)
def test_aaparallel_function_call(model):
try:
litellm.set_verbose = True
11 changes: 11 additions & 0 deletions litellm/tests/test_text_completion.py
@@ -4239,3 +4239,14 @@ def test_completion_vllm():
mock_call.assert_called_once()

assert "hello" in mock_call.call_args.kwargs["extra_body"]


def test_completion_fireworks_ai_multiple_choices():
litellm.set_verbose = True
response = litellm.text_completion(
model="fireworks_ai/llama-v3p1-8b-instruct",
prompt=["halo", "hi", "halo", "hi"],
)
print(response.choices)

assert len(response.choices) == 4
1 change: 1 addition & 0 deletions proxy_server_config.yaml
@@ -148,6 +148,7 @@ router_settings:
redis_password: os.environ/REDIS_PASSWORD
redis_port: os.environ/REDIS_PORT
enable_pre_call_checks: true
model_group_alias: {"my-special-fake-model-alias-name": "fake-openai-endpoint-3"}

general_settings:
master_key: sk-1234 # [OPTIONAL] Use to enforce auth on proxy. See - https://docs.litellm.ai/docs/proxy/virtual_keys
