Merge branch 'main' into litellm_webhook_support
krrishdholakia authored May 21, 2024
2 parents 9d815be + ad91bff commit 707cf24
Showing 19 changed files with 830 additions and 88 deletions.
16 changes: 16 additions & 0 deletions docs/my-website/docs/image_generation.md
@@ -150,4 +150,20 @@ response = image_generation(
model="bedrock/stability.stable-diffusion-xl-v0",
)
print(f"response: {response}")
```

## VertexAI - Image Generation Models

### Usage

Use `litellm.image_generation` for image generation models on VertexAI.

```python
response = litellm.image_generation(
prompt="An olympic size swimming pool",
model="vertex_ai/imagegeneration@006",
vertex_ai_project="adroit-crow-413218",
vertex_ai_location="us-central1",
)
print(f"response: {response}")
```
25 changes: 25 additions & 0 deletions docs/my-website/docs/providers/vertex.md
@@ -508,6 +508,31 @@ All models listed [here](https://github.com/BerriAI/litellm/blob/57f37f743886a02
| text-embedding-preview-0409 | `embedding(model="vertex_ai/text-embedding-preview-0409", input)` |
| text-multilingual-embedding-preview-0409 | `embedding(model="vertex_ai/text-multilingual-embedding-preview-0409", input)` |

## Image Generation Models

Usage

```python
response = await litellm.aimage_generation(
prompt="An olympic size swimming pool",
model="vertex_ai/imagegeneration@006",
vertex_ai_project="adroit-crow-413218",
vertex_ai_location="us-central1",
)
```

**Generating multiple images**

Use the `n` parameter to specify how many images you want generated.
```python
response = await litellm.aimage_generation(
prompt="An olympic size swimming pool",
model="vertex_ai/imagegeneration@006",
vertex_ai_project="adroit-crow-413218",
vertex_ai_location="us-central1",
n=1,
)
```

## Extra

Expand Down
143 changes: 129 additions & 14 deletions docs/my-website/docs/proxy/call_hooks.md
@@ -25,26 +25,45 @@ class MyCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/observabilit
def __init__(self):
pass

#### ASYNC ####

async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
pass
#### CALL HOOKS - proxy only ####

async def async_log_pre_api_call(self, model, messages, kwargs):
pass
async def async_pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: Literal[
"completion",
"text_completion",
"embeddings",
"image_generation",
"moderation",
"audio_transcription",
]) -> Optional[Union[Exception, str, dict]]:
data["model"] = "my-new-model"
return data

async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
async def async_post_call_failure_hook(
self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth
):
pass

async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
async def async_post_call_success_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
response,
):
pass

#### CALL HOOKS - proxy only ####

async def async_pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: Literal["completion", "embeddings"]):
data["model"] = "my-new-model"
return data
async def async_moderation_hook( # call made in parallel to llm api call
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
call_type: Literal["completion", "embeddings", "image_generation"],
):
pass

async def async_post_call_streaming_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
response: str,
):
pass
proxy_handler_instance = MyCustomHandler()
```

@@ -190,4 +209,100 @@ general_settings:
**Result**
<Image img={require('../../img/end_user_enforcement.png')}/>
<Image img={require('../../img/end_user_enforcement.png')}/>
## Advanced - Return rejected message as response
For chat completion and text completion calls, you can return a rejected message as the user-facing response.
Do this by returning a string from the hook. LiteLLM takes care of returning the response in the correct format for the endpoint, whether the call is streaming or non-streaming.
For endpoints other than chat/text completion, the rejected message is returned as a 400 status code exception.
### 1. Create Custom Handler
```python
from typing import Literal, Optional, Union

from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache
import litellm
from litellm.utils import get_formatted_prompt

# This file includes the custom callbacks for LiteLLM Proxy
# Once defined, these can be passed in proxy_config.yaml
class MyCustomHandler(CustomLogger):
def __init__(self):
pass

#### CALL HOOKS - proxy only ####

async def async_pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: Literal[
"completion",
"text_completion",
"embeddings",
"image_generation",
"moderation",
"audio_transcription",
]) -> Optional[Union[Exception, str, dict]]:
formatted_prompt = get_formatted_prompt(data=data, call_type=call_type)

if "Hello world" in formatted_prompt:
return "This is an invalid response"

return data

proxy_handler_instance = MyCustomHandler()
```

### 2. Update config.yaml

```yaml
model_list:
- model_name: gpt-3.5-turbo
litellm_params:
model: gpt-3.5-turbo

litellm_settings:
callbacks: custom_callbacks.proxy_handler_instance # sets litellm.callbacks = [proxy_handler_instance]
```
### 3. Test it!
```shell
$ litellm /path/to/config.yaml
```
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--data ' {
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "user",
"content": "Hello world"
}
]
}'
```

**Expected Response**

```
{
"id": "chatcmpl-d00bbede-2d90-4618-bf7b-11a1c23cf360",
"choices": [
{
"finish_reason": "stop",
"index": 0,
"message": {
"content": "This is an invalid response.", # 👈 REJECTED RESPONSE
"role": "assistant"
}
}
],
"created": 1716234198,
"model": null,
"object": "chat.completion",
"system_fingerprint": null,
"usage": {}
}
```
3 changes: 3 additions & 0 deletions litellm/__init__.py
@@ -724,6 +724,9 @@ def identify(event_details):
get_supported_openai_params,
get_api_base,
get_first_chars_messages,
ModelResponse,
ImageResponse,
ImageObject,
)
from .llms.huggingface_restapi import HuggingfaceConfig
from .llms.anthropic import AnthropicConfig
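As a quick orientation on what this export change gives callers, a hedged sketch (assumes a litellm install at or after this commit):

```python
import litellm

# The names added to the import block above are exposed as attributes of the
# top-level litellm package.
print(litellm.ModelResponse)
print(litellm.ImageResponse)
print(litellm.ImageObject)
```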
26 changes: 26 additions & 0 deletions litellm/exceptions.py
@@ -177,6 +177,32 @@ def __init__(
) # Call the base class constructor with the parameters it needs


# Subclass of BadRequestError - meant to help us catch guardrails-related errors on the proxy.
class RejectedRequestError(BadRequestError): # type: ignore
def __init__(
self,
message,
model,
llm_provider,
request_data: dict,
litellm_debug_info: Optional[str] = None,
):
self.status_code = 400
self.message = message
self.model = model
self.llm_provider = llm_provider
self.litellm_debug_info = litellm_debug_info
self.request_data = request_data
request = httpx.Request(method="POST", url="https://api.openai.com/v1")
response = httpx.Response(status_code=500, request=request)
super().__init__(
message=self.message,
model=self.model, # type: ignore
llm_provider=self.llm_provider, # type: ignore
response=response,
) # Call the base class constructor with the parameters it needs


class ContentPolicyViolationError(BadRequestError): # type: ignore
# Error code: 400 - {'error': {'code': 'content_policy_violation', 'message': 'Your request was rejected as a result of our safety system. Image descriptions generated from your prompt may contain text that is not allowed by our safety system. If you believe this was done in error, your request may succeed if retried, or by adjusting your prompt.', 'param': None, 'type': 'invalid_request_error'}}
def __init__(
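For orientation, here is a minimal, hypothetical sketch of constructing and raising `RejectedRequestError` from guardrail-style code. The function and the rejection check are illustrative and not part of this commit; only the exception's constructor arguments come from the class above.

```python
from litellm.exceptions import RejectedRequestError


def reject_if_blocked(data: dict) -> dict:
    # Hypothetical guardrail check - illustrative only.
    messages = data.get("messages", [])
    text = " ".join(m.get("content", "") for m in messages if isinstance(m, dict))

    if "blocked phrase" in text:
        # Raise the new exception with the request payload attached,
        # so callers can inspect what was rejected.
        raise RejectedRequestError(
            message="Request rejected by guardrail",
            model=data.get("model", ""),
            llm_provider="openai",
            request_data=data,
        )
    return data
```

Because it subclasses `BadRequestError` and sets `status_code = 400`, existing handlers that catch `BadRequestError` will also catch this error.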
14 changes: 11 additions & 3 deletions litellm/integrations/custom_logger.py
@@ -4,7 +4,6 @@

from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache

from typing import Literal, Union, Optional
import traceback

@@ -64,8 +63,17 @@ async def async_pre_call_hook(
user_api_key_dict: UserAPIKeyAuth,
cache: DualCache,
data: dict,
call_type: Literal["completion", "embeddings", "image_generation"],
):
call_type: Literal[
"completion",
"text_completion",
"embeddings",
"image_generation",
"moderation",
"audio_transcription",
],
) -> Optional[
Union[Exception, str, dict]
]: # raise exception if invalid, return a str for the user to receive - if rejected, or return a modified dictionary for passing into litellm
pass

async def async_post_call_failure_hook(
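To make the new return contract concrete, here is a minimal sketch of a subclass exercising the three documented paths: raise to block the request, return a `str` to send the user a rejection message, or return a `dict` to forward modified data to litellm. The class name and checks are hypothetical; the signature mirrors the base class above.

```python
from typing import Literal, Optional, Union

from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache


class ExamplePreCallHook(CustomLogger):
    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: Literal[
            "completion",
            "text_completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
        ],
    ) -> Optional[Union[Exception, str, dict]]:
        # Hypothetical checks - illustrative only.
        if (data.get("metadata") or {}).get("block") is True:
            raise Exception("Request blocked")  # raise -> request fails with an error

        if call_type == "completion" and not data.get("messages"):
            return "Please provide at least one message."  # str -> returned to the user as the rejection message

        data.setdefault("model", "gpt-3.5-turbo")
        return data  # dict -> modified request is passed on to litellm
```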
48 changes: 29 additions & 19 deletions litellm/integrations/slack_alerting.py
@@ -871,27 +871,37 @@ async def send_alert(

async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
"""Log deployment latency"""
if "daily_reports" in self.alert_types:
model_id = (
kwargs.get("litellm_params", {}).get("model_info", {}).get("id", "")
)
response_s: timedelta = end_time - start_time

final_value = response_s
total_tokens = 0

if isinstance(response_obj, litellm.ModelResponse):
completion_tokens = response_obj.usage.completion_tokens
final_value = float(response_s.total_seconds() / completion_tokens)

await self.async_update_daily_reports(
DeploymentMetrics(
id=model_id,
failed_request=False,
latency_per_output_token=final_value,
updated_at=litellm.utils.get_utc_datetime(),
try:
if "daily_reports" in self.alert_types:
model_id = (
kwargs.get("litellm_params", {}).get("model_info", {}).get("id", "")
)
response_s: timedelta = end_time - start_time

final_value = response_s
total_tokens = 0

if isinstance(response_obj, litellm.ModelResponse):
completion_tokens = response_obj.usage.completion_tokens
if completion_tokens is not None and completion_tokens > 0:
final_value = float(
response_s.total_seconds() / completion_tokens
)

await self.async_update_daily_reports(
DeploymentMetrics(
id=model_id,
failed_request=False,
latency_per_output_token=final_value,
updated_at=litellm.utils.get_utc_datetime(),
)
)
except Exception as e:
verbose_proxy_logger.error(
"[Non-Blocking Error] Slack Alerting: Got error in logging LLM deployment latency: ",
e,
)
pass

async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
"""Log failure + deployment latency"""
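For intuition on the metric guarded above: `final_value` is seconds of latency per completion token when a token count is available, and the raw response latency otherwise. A standalone sketch with made-up numbers, not taken from the commit:

```python
from datetime import timedelta

# Illustrative values only.
response_s = timedelta(seconds=2.4)  # end_time - start_time
completion_tokens = 120              # from the ModelResponse usage block

if completion_tokens is not None and completion_tokens > 0:
    final_value = response_s.total_seconds() / completion_tokens  # 0.02 s per output token
else:
    final_value = response_s  # fall back to the raw response latency

print(final_value)
```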
6 changes: 3 additions & 3 deletions litellm/llms/openai.py
@@ -96,7 +96,7 @@ def __init__(
safe_prompt: Optional[bool] = None,
response_format: Optional[dict] = None,
) -> None:
locals_ = locals()
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@@ -211,7 +211,7 @@ def __init__(
temperature: Optional[int] = None,
top_p: Optional[int] = None,
) -> None:
locals_ = locals()
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@@ -335,7 +335,7 @@ def __init__(
temperature: Optional[float] = None,
top_p: Optional[float] = None,
) -> None:
locals_ = locals()
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
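The only change in these config classes is `locals()` → `locals().copy()`: iterating over a private copy means the loop cannot be disturbed if the frame's locals mapping is refreshed while it runs (presumably the failure this guards against; the commit itself does not say). A minimal sketch of the pattern with a hypothetical config class:

```python
from typing import Optional


class ExampleConfig:
    """Hypothetical config class following the same pattern as the OpenAI configs above."""

    max_tokens: Optional[int] = None
    temperature: Optional[float] = None

    def __init__(
        self,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
    ) -> None:
        # Take a snapshot of the local namespace before iterating, so the loop
        # is unaffected by any later refresh of the frame's locals mapping.
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)


ExampleConfig(max_tokens=256)
print(ExampleConfig.max_tokens)  # 256
```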