Skip to content

Commit

Permalink
Add rate limits to assistant chat messages (#3514)
Browse files Browse the repository at this point in the history
  • Loading branch information
olliestanley authored Jun 30, 2023
1 parent b9a6e99 commit 478aedf
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 10 deletions.
25 changes: 25 additions & 0 deletions inference/server/main.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
import asyncio
import math
import signal
import sys

import fastapi
import redis.asyncio as redis
import sqlmodel
from fastapi.middleware.cors import CORSMiddleware
from fastapi_limiter import FastAPILimiter
from loguru import logger
from oasst_inference_server import database, deps, models, plugins
from oasst_inference_server.routes import account, admin, auth, chats, configs, workers
from oasst_inference_server.settings import settings
from oasst_shared.schemas import inference
from prometheus_fastapi_instrumentator import Instrumentator
from starlette.middleware.sessions import SessionMiddleware
from starlette.status import HTTP_429_TOO_MANY_REQUESTS

app = fastapi.FastAPI(title=settings.PROJECT_NAME)

Expand Down Expand Up @@ -81,6 +85,27 @@ async def alembic_upgrade():
signal.signal(signal.SIGINT, signal.SIG_DFL)


@app.on_event("startup")
async def setup_rate_limiter():
if not settings.rate_limit:
logger.warning("Skipping rate limiter setup on startup (rate_limit is False)")
return

async def http_callback(request: fastapi.Request, response: fastapi.Response, pexpire: int):
"""Error callback function when too many requests"""
expire = math.ceil(pexpire / 1000)
raise fastapi.HTTPException(f"Too Many Requests. Retry After {expire} seconds.", HTTP_429_TOO_MANY_REQUESTS)

try:
client = redis.Redis(
host=settings.redis_host, port=settings.redis_port, db=settings.redis_ratelim_db, decode_responses=True
)
logger.info(f"Connected to {client=}")
await FastAPILimiter.init(client, http_callback=http_callback)
except Exception:
logger.exception("Failed to establish Redis connection")


@app.on_event("startup")
async def maybe_add_debug_api_keys():
debug_api_keys = settings.debug_api_keys_list
Expand Down
26 changes: 17 additions & 9 deletions inference/server/oasst_inference_server/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,14 @@
trusted_client_scheme = APIKeyHeader(name="TrustedClient", auto_error=False, scheme_name="TrustedClient")


def get_current_user_id(
token: str = Security(authorization_scheme), trusted_client_token: str = Security(trusted_client_scheme)
) -> str:
"""Get the current user ID by decoding the JWT token."""
if trusted_client_token is not None:
info: auth.TrustedClient = auth.TrustedClientToken(content=trusted_client_token).content
if info.api_key not in settings.trusted_api_keys_list:
raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Unauthorized client")
return info.user_id
def get_user_id_from_trusted_client_token(trusted_client_token: str) -> str:
info: auth.TrustedClient = auth.TrustedClientToken(content=trusted_client_token).content
if info.api_key not in settings.trusted_api_keys_list:
raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Unauthorized client")
return info.user_id


def get_user_id_from_auth_token(token: str) -> str:
if token is None or not token.startswith("Bearer "):
logger.warning(f"Invalid token: {token}")
raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail="Not authenticated")
Expand All @@ -56,6 +54,16 @@ def get_current_user_id(
return user_id


def get_current_user_id(
token: str = Security(authorization_scheme), trusted_client_token: str = Security(trusted_client_scheme)
) -> str:
"""Get the current user ID."""
if trusted_client_token is not None:
return get_user_id_from_trusted_client_token(trusted_client_token)

return get_user_id_from_auth_token(token)


def create_access_token(user_id: str) -> str:
"""Create encoded JSON Web Token (JWT) for the given user ID."""
payload: bytes = build_payload(user_id, "access", settings.auth_access_token_expire_minutes)
Expand Down
19 changes: 19 additions & 0 deletions inference/server/oasst_inference_server/deps.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import contextlib

import fastapi
import redis.asyncio as redis
from fastapi import Depends
from fastapi_limiter.depends import RateLimiter
from oasst_inference_server import auth
from oasst_inference_server.chat_repository import ChatRepository
from oasst_inference_server.database import AsyncSession, get_async_session
Expand Down Expand Up @@ -54,3 +56,20 @@ async def manual_chat_repository():
async def manual_user_chat_repository(user_id: str):
async with manual_create_session() as session:
yield await create_user_chat_repository(session, user_id)


async def user_identifier(request: fastapi.Request) -> str:
"""Identify a request by user based on auth header"""
trusted_client_token = request.headers.get("TrustedClient")
if trusted_client_token is not None:
return auth.get_user_id_from_trusted_client_token(trusted_client_token)

token = request.headers.get("Authorization")
return auth.get_user_id_from_auth_token(token)


class UserRateLimiter(RateLimiter):
def __init__(
self, times: int = 100, milliseconds: int = 0, seconds: int = 0, minutes: int = 1, hours: int = 0
) -> None:
super().__init__(times, milliseconds, seconds, minutes, hours, user_identifier)
12 changes: 11 additions & 1 deletion inference/server/oasst_inference_server/routes/chats.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,17 @@ async def create_prompter_message(
return fastapi.Response(status_code=500)


@router.post("/{chat_id}/assistant_message")
@router.post(
"/{chat_id}/assistant_message",
dependencies=[
Depends(
deps.UserRateLimiter(
times=settings.rate_limit_messages_user_times,
seconds=settings.rate_limit_messages_user_seconds,
)
),
],
)
async def create_assistant_message(
chat_id: str,
request: chat_schema.CreateAssistantMessageRequest,
Expand Down
5 changes: 5 additions & 0 deletions inference/server/oasst_inference_server/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,18 @@ class Settings(pydantic.BaseSettings):
redis_host: str = "localhost"
redis_port: int = 6379
redis_db: int = 0
redis_ratelim_db: int = 1

message_queue_expire: int = 60
work_queue_max_size: int | None = None

chat_max_messages: int | None = None
message_max_length: int | None = None

rate_limit: bool = True
rate_limit_messages_user_times: int = 20
rate_limit_messages_user_seconds: int = 600

allowed_worker_compat_hashes: str = "*"

@property
Expand Down
1 change: 1 addition & 0 deletions inference/server/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ asyncpg
authlib
beautifulsoup4 # web_retriever plugin
cryptography==39.0.0
fastapi-limiter
fastapi[all]==0.88.0
google-api-python-client
google-auth-httplib2
Expand Down

0 comments on commit 478aedf

Please sign in to comment.