From 3bca279b35fbd5f014c30c3356e8e70ec3493691 Mon Sep 17 00:00:00 2001 From: Aarushi <50577581+aarushik93@users.noreply.github.com> Date: Tue, 3 Dec 2024 09:25:29 +0000 Subject: [PATCH] feat(libs): Add API key rate limit middleware (#8850) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Once we release api key feature, we will want to be able to rate limit as well. This is the foundation for that. For now it is a blanket rate limit, later we will be able to add tiered rate limits ### Changes 🏗️ Added new middleware libary in autogpt_libs which contains the logic for getting the api key, storing it's details in redis and checking how many requests it's done, how many are left and what the reset time is. --------- Co-authored-by: Zamil Majdy Co-authored-by: Reinier van der Leer --- .../autogpt_libs/feature_flag/client.py | 2 +- .../autogpt_libs/logging/config.py | 1 - .../autogpt_libs/logging/test_utils.py | 6 +-- .../autogpt_libs/rate_limit/__init__.py | 0 .../autogpt_libs/rate_limit/config.py | 31 +++++++++++ .../autogpt_libs/rate_limit/limiter.py | 51 +++++++++++++++++++ .../autogpt_libs/rate_limit/middleware.py | 32 ++++++++++++ 7 files changed, 118 insertions(+), 5 deletions(-) create mode 100644 autogpt_platform/autogpt_libs/autogpt_libs/rate_limit/__init__.py create mode 100644 autogpt_platform/autogpt_libs/autogpt_libs/rate_limit/config.py create mode 100644 autogpt_platform/autogpt_libs/autogpt_libs/rate_limit/limiter.py create mode 100644 autogpt_platform/autogpt_libs/autogpt_libs/rate_limit/middleware.py diff --git a/autogpt_platform/autogpt_libs/autogpt_libs/feature_flag/client.py b/autogpt_platform/autogpt_libs/autogpt_libs/feature_flag/client.py index d9f081e5cb4b..dde516c1d8fa 100644 --- a/autogpt_platform/autogpt_libs/autogpt_libs/feature_flag/client.py +++ b/autogpt_platform/autogpt_libs/autogpt_libs/feature_flag/client.py @@ -72,7 +72,7 @@ def feature_flag( """ def decorator( - func: Callable[P, Union[T, Awaitable[T]]] + func: Callable[P, Union[T, Awaitable[T]]], ) -> Callable[P, Union[T, Awaitable[T]]]: @wraps(func) async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: diff --git a/autogpt_platform/autogpt_libs/autogpt_libs/logging/config.py b/autogpt_platform/autogpt_libs/autogpt_libs/logging/config.py index 523f6cf8ec87..10db444247c6 100644 --- a/autogpt_platform/autogpt_libs/autogpt_libs/logging/config.py +++ b/autogpt_platform/autogpt_libs/autogpt_libs/logging/config.py @@ -23,7 +23,6 @@ class LoggingConfig(BaseSettings): - level: str = Field( default="INFO", description="Logging level", diff --git a/autogpt_platform/autogpt_libs/autogpt_libs/logging/test_utils.py b/autogpt_platform/autogpt_libs/autogpt_libs/logging/test_utils.py index b3682d42cf0d..24e20986549c 100644 --- a/autogpt_platform/autogpt_libs/autogpt_libs/logging/test_utils.py +++ b/autogpt_platform/autogpt_libs/autogpt_libs/logging/test_utils.py @@ -24,10 +24,10 @@ ), ("", ""), ("hello", "hello"), - ("hello\x1B[31m world", "hello world"), - ("\x1B[36mHello,\x1B[32m World!", "Hello, World!"), + ("hello\x1b[31m world", "hello world"), + ("\x1b[36mHello,\x1b[32m World!", "Hello, World!"), ( - "\x1B[1m\x1B[31mError:\x1B[0m\x1B[31m file not found", + "\x1b[1m\x1b[31mError:\x1b[0m\x1b[31m file not found", "Error: file not found", ), ], diff --git a/autogpt_platform/autogpt_libs/autogpt_libs/rate_limit/__init__.py b/autogpt_platform/autogpt_libs/autogpt_libs/rate_limit/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/autogpt_platform/autogpt_libs/autogpt_libs/rate_limit/config.py b/autogpt_platform/autogpt_libs/autogpt_libs/rate_limit/config.py new file mode 100644 index 000000000000..76c9abaa0729 --- /dev/null +++ b/autogpt_platform/autogpt_libs/autogpt_libs/rate_limit/config.py @@ -0,0 +1,31 @@ +from pydantic import Field +from pydantic_settings import BaseSettings, SettingsConfigDict + + +class RateLimitSettings(BaseSettings): + redis_host: str = Field( + default="redis://localhost:6379", + description="Redis host", + validation_alias="REDIS_HOST", + ) + + redis_port: str = Field( + default="6379", description="Redis port", validation_alias="REDIS_PORT" + ) + + redis_password: str = Field( + default="password", + description="Redis password", + validation_alias="REDIS_PASSWORD", + ) + + requests_per_minute: int = Field( + default=60, + description="Maximum number of requests allowed per minute per API key", + validation_alias="RATE_LIMIT_REQUESTS_PER_MINUTE", + ) + + model_config = SettingsConfigDict(case_sensitive=True, extra="ignore") + + +RATE_LIMIT_SETTINGS = RateLimitSettings() diff --git a/autogpt_platform/autogpt_libs/autogpt_libs/rate_limit/limiter.py b/autogpt_platform/autogpt_libs/autogpt_libs/rate_limit/limiter.py new file mode 100644 index 000000000000..efad05836f4f --- /dev/null +++ b/autogpt_platform/autogpt_libs/autogpt_libs/rate_limit/limiter.py @@ -0,0 +1,51 @@ +import time +from typing import Tuple + +from redis import Redis + +from .config import RATE_LIMIT_SETTINGS + + +class RateLimiter: + def __init__( + self, + redis_host: str = RATE_LIMIT_SETTINGS.redis_host, + redis_port: str = RATE_LIMIT_SETTINGS.redis_port, + redis_password: str = RATE_LIMIT_SETTINGS.redis_password, + requests_per_minute: int = RATE_LIMIT_SETTINGS.requests_per_minute, + ): + self.redis = Redis( + host=redis_host, + port=redis_port, + password=redis_password, + decode_responses=True, + ) + self.window = 60 + self.max_requests = requests_per_minute + + async def check_rate_limit(self, api_key_id: str) -> Tuple[bool, int, int]: + """ + Check if request is within rate limits. + + Args: + api_key_id: The API key identifier to check + + Returns: + Tuple of (is_allowed, remaining_requests, reset_time) + """ + now = time.time() + window_start = now - self.window + key = f"ratelimit:{api_key_id}:1min" + + pipe = self.redis.pipeline() + pipe.zremrangebyscore(key, 0, window_start) + pipe.zadd(key, {str(now): now}) + pipe.zcount(key, window_start, now) + pipe.expire(key, self.window) + + _, _, request_count, _ = pipe.execute() + + remaining = max(0, self.max_requests - request_count) + reset_time = int(now + self.window) + + return request_count <= self.max_requests, remaining, reset_time diff --git a/autogpt_platform/autogpt_libs/autogpt_libs/rate_limit/middleware.py b/autogpt_platform/autogpt_libs/autogpt_libs/rate_limit/middleware.py new file mode 100644 index 000000000000..496697d8b1e2 --- /dev/null +++ b/autogpt_platform/autogpt_libs/autogpt_libs/rate_limit/middleware.py @@ -0,0 +1,32 @@ +from fastapi import HTTPException, Request +from starlette.middleware.base import RequestResponseEndpoint + +from .limiter import RateLimiter + + +async def rate_limit_middleware(request: Request, call_next: RequestResponseEndpoint): + """FastAPI middleware for rate limiting API requests.""" + limiter = RateLimiter() + + if not request.url.path.startswith("/api"): + return await call_next(request) + + api_key = request.headers.get("Authorization") + if not api_key: + return await call_next(request) + + api_key = api_key.replace("Bearer ", "") + + is_allowed, remaining, reset_time = await limiter.check_rate_limit(api_key) + + if not is_allowed: + raise HTTPException( + status_code=429, detail="Rate limit exceeded. Please try again later." + ) + + response = await call_next(request) + response.headers["X-RateLimit-Limit"] = str(limiter.max_requests) + response.headers["X-RateLimit-Remaining"] = str(remaining) + response.headers["X-RateLimit-Reset"] = str(reset_time) + + return response