Add rate limits to assistant chat messages #3514

Merged · 4 commits · Jun 30, 2023
25 changes: 25 additions & 0 deletions inference/server/main.py
@@ -1,17 +1,21 @@
import asyncio
import math
import signal
import sys

import fastapi
import redis.asyncio as redis
import sqlmodel
from fastapi.middleware.cors import CORSMiddleware
from fastapi_limiter import FastAPILimiter
from loguru import logger
from oasst_inference_server import database, deps, models, plugins
from oasst_inference_server.routes import account, admin, auth, chats, configs, workers
from oasst_inference_server.settings import settings
from oasst_shared.schemas import inference
from prometheus_fastapi_instrumentator import Instrumentator
from starlette.middleware.sessions import SessionMiddleware
from starlette.status import HTTP_429_TOO_MANY_REQUESTS

app = fastapi.FastAPI(title=settings.PROJECT_NAME)

@@ -81,6 +85,27 @@ async def alembic_upgrade():
    signal.signal(signal.SIGINT, signal.SIG_DFL)


@app.on_event("startup")
async def setup_rate_limiter():
if not settings.rate_limit:
logger.warning("Skipping rate limiter setup on startup (rate_limit is False)")
return

async def http_callback(request: fastapi.Request, response: fastapi.Response, pexpire: int):
"""Error callback function when too many requests"""
expire = math.ceil(pexpire / 1000)
raise fastapi.HTTPException(f"Too Many Requests. Retry After {expire} seconds.", HTTP_429_TOO_MANY_REQUESTS)

try:
client = redis.Redis(
host=settings.redis_host, port=settings.redis_port, db=settings.redis_ratelim_db, decode_responses=True
)
logger.info(f"Connected to {client=}")
await FastAPILimiter.init(client, http_callback=http_callback)
except Exception:
logger.exception("Failed to establish Redis connection")


@app.on_event("startup")
async def maybe_add_debug_api_keys():
    debug_api_keys = settings.debug_api_keys_list
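
The startup hook above wires FastAPILimiter to a dedicated Redis database (`redis_ratelim_db`, separate from the default `redis_db`). A minimal, hypothetical sketch for checking that this database is reachable before starting the server, assuming local defaults (localhost:6379, db 1):

```python
import asyncio

import redis.asyncio as redis


async def check_ratelimit_db() -> None:
    # Assumed local defaults; in the server these come from settings.redis_host,
    # settings.redis_port and settings.redis_ratelim_db.
    client = redis.Redis(host="localhost", port=6379, db=1, decode_responses=True)
    print(await client.ping())  # True if FastAPILimiter can store its counters here


asyncio.run(check_ratelimit_db())
```
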
26 changes: 17 additions & 9 deletions inference/server/oasst_inference_server/auth.py
@@ -23,16 +23,14 @@
trusted_client_scheme = APIKeyHeader(name="TrustedClient", auto_error=False, scheme_name="TrustedClient")


def get_current_user_id(
    token: str = Security(authorization_scheme), trusted_client_token: str = Security(trusted_client_scheme)
) -> str:
    """Get the current user ID by decoding the JWT token."""
    if trusted_client_token is not None:
        info: auth.TrustedClient = auth.TrustedClientToken(content=trusted_client_token).content
        if info.api_key not in settings.trusted_api_keys_list:
            raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Unauthorized client")
        return info.user_id
def get_user_id_from_trusted_client_token(trusted_client_token: str) -> str:
    info: auth.TrustedClient = auth.TrustedClientToken(content=trusted_client_token).content
    if info.api_key not in settings.trusted_api_keys_list:
        raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Unauthorized client")
    return info.user_id


def get_user_id_from_auth_token(token: str) -> str:
    if token is None or not token.startswith("Bearer "):
        logger.warning(f"Invalid token: {token}")
        raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail="Not authenticated")
@@ -56,6 +54,16 @@ def get_current_user_id(
    return user_id


def get_current_user_id(
    token: str = Security(authorization_scheme), trusted_client_token: str = Security(trusted_client_scheme)
) -> str:
    """Get the current user ID."""
    if trusted_client_token is not None:
        return get_user_id_from_trusted_client_token(trusted_client_token)

    return get_user_id_from_auth_token(token)


def create_access_token(user_id: str) -> str:
    """Create encoded JSON Web Token (JWT) for the given user ID."""
    payload: bytes = build_payload(user_id, "access", settings.auth_access_token_expire_minutes)
19 changes: 19 additions & 0 deletions inference/server/oasst_inference_server/deps.py
@@ -1,7 +1,9 @@
import contextlib

import fastapi
import redis.asyncio as redis
from fastapi import Depends
from fastapi_limiter.depends import RateLimiter
from oasst_inference_server import auth
from oasst_inference_server.chat_repository import ChatRepository
from oasst_inference_server.database import AsyncSession, get_async_session
@@ -54,3 +56,20 @@ async def manual_chat_repository():
async def manual_user_chat_repository(user_id: str):
    async with manual_create_session() as session:
        yield await create_user_chat_repository(session, user_id)


async def user_identifier(request: fastapi.Request) -> str:
    """Identify a request by user, based on the auth headers."""
    trusted_client_token = request.headers.get("TrustedClient")
    if trusted_client_token is not None:
        return auth.get_user_id_from_trusted_client_token(trusted_client_token)

    token = request.headers.get("Authorization")
    return auth.get_user_id_from_auth_token(token)


class UserRateLimiter(RateLimiter):
    def __init__(
        self, times: int = 100, milliseconds: int = 0, seconds: int = 0, minutes: int = 1, hours: int = 0
    ) -> None:
        super().__init__(times, milliseconds, seconds, minutes, hours, user_identifier)
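
Because `UserRateLimiter` passes `user_identifier` as the identifier callback, counters are keyed per authenticated user rather than per client address (the fastapi-limiter default). A hedged sketch of how it could be attached to a hypothetical route, relying on the class defaults of 100 requests per minute; the route and path are illustrative only, and the real usage is on the assistant_message endpoint in routes/chats.py below:

```python
import fastapi
from fastapi import Depends

from oasst_inference_server import deps

router = fastapi.APIRouter()


@router.get("/example", dependencies=[Depends(deps.UserRateLimiter())])
async def example_endpoint() -> dict:
    # Hypothetical endpoint: each user gets at most 100 requests per minute here.
    return {"ok": True}
```
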
12 changes: 11 additions & 1 deletion inference/server/oasst_inference_server/routes/chats.py
@@ -117,7 +117,17 @@ async def create_prompter_message(
        return fastapi.Response(status_code=500)


@router.post("/{chat_id}/assistant_message")
@router.post(
"/{chat_id}/assistant_message",
dependencies=[
Depends(
deps.UserRateLimiter(
times=settings.rate_limit_messages_user_times,
seconds=settings.rate_limit_messages_user_seconds,
)
),
],
)
async def create_assistant_message(
chat_id: str,
request: chat_schema.CreateAssistantMessageRequest,
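
With this dependency in place, a user who sends more than `rate_limit_messages_user_times` requests to this endpoint within `rate_limit_messages_user_seconds` (20 per 600 s by default) should start receiving 429 responses from the `http_callback` registered at startup. A rough client-side sketch; the base URL, path, token, and request payload are all placeholders that depend on the deployment and on `chat_schema.CreateAssistantMessageRequest`:

```python
import httpx

BASE_URL = "http://localhost:8000"                    # placeholder deployment URL
HEADERS = {"Authorization": "Bearer <access-token>"}  # placeholder credentials
URL = "/chats/<chat-id>/assistant_message"            # placeholder chat id and path
PAYLOAD: dict = {}  # fill in per chat_schema.CreateAssistantMessageRequest

with httpx.Client(base_url=BASE_URL, headers=HEADERS) as client:
    for i in range(25):
        resp = client.post(URL, json=PAYLOAD)
        if resp.status_code == 429:
            print(f"Rate limited after {i} requests: {resp.text}")
            break
```
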
5 changes: 5 additions & 0 deletions inference/server/oasst_inference_server/settings.py
@@ -14,13 +14,18 @@ class Settings(pydantic.BaseSettings):
    redis_host: str = "localhost"
    redis_port: int = 6379
    redis_db: int = 0
    redis_ratelim_db: int = 1

    message_queue_expire: int = 60
    work_queue_max_size: int | None = None

    chat_max_messages: int | None = None
    message_max_length: int | None = None

    rate_limit: bool = True
    rate_limit_messages_user_times: int = 20
    rate_limit_messages_user_seconds: int = 600

    allowed_worker_compat_hashes: str = "*"

    @property
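
These are pydantic BaseSettings fields, so the new limits can be tuned per deployment without code changes, either through environment variables or by constructing Settings directly. A small illustrative sketch, assuming no env prefix is configured and that all other fields resolve from their defaults or the environment:

```python
import os

from oasst_inference_server.settings import Settings

# Via the environment (how a deployment would normally tune the limits):
os.environ["RATE_LIMIT_MESSAGES_USER_TIMES"] = "5"
os.environ["RATE_LIMIT_MESSAGES_USER_SECONDS"] = "60"
print(Settings().rate_limit_messages_user_times)  # 5

# Or directly, e.g. in a test:
print(Settings(rate_limit=False).rate_limit)  # False
```
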
1 change: 1 addition & 0 deletions inference/server/requirements.txt
@@ -4,6 +4,7 @@ asyncpg
authlib
beautifulsoup4 # web_retriever plugin
cryptography==39.0.0
fastapi-limiter
fastapi[all]==0.88.0
google-api-python-client
google-auth-httplib2