Skip to content
This repository has been archived by the owner on Mar 13, 2023. It is now read-only.

fix: resolve global ratelimit being exceeded #679

Merged
merged 1 commit into from
Oct 18, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 40 additions & 29 deletions naff/api/http/http_client.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""This file handles the interaction with discords http endpoints."""
import asyncio
import time
from typing import Any, cast
from urllib.parse import quote as _uriquote
from weakref import WeakValueDictionary

import aiohttp
import discord_typings
from aiohttp import BaseConnector, ClientSession, ClientWebSocketResponse, FormData
from multidict import CIMultiDictProxy

Expand Down Expand Up @@ -34,38 +36,49 @@
from naff.client.errors import DiscordError, Forbidden, GatewayNotFound, HTTPException, NotFound, LoginError
from naff.client.utils.input_utils import response_decode, OverriddenJson
from naff.client.utils.serializer import dict_filter
from naff.models import CooldownSystem
from naff.models.discord.file import UPLOADABLE_TYPE
from .route import Route
import discord_typings

__all__ = ("HTTPClient",)


class GlobalLock:
"""Manages the global ratelimit"""

def __init__(self) -> None:
self.cooldown_system: CooldownSystem = CooldownSystem(
45, 1
) # global rate-limit is 50 per second, conservatively we use 45
self._lock: asyncio.Lock = asyncio.Lock()
self._lock = asyncio.Lock()
self.max_requests = 45
self._calls = 0
self._reset_time = 0

async def rate_limit(self) -> None:
async with self._lock:
while not self.cooldown_system.acquire_token():
await asyncio.sleep(self.cooldown_system.get_cooldown_time())
@property
def calls_remaining(self) -> int:
"""Returns the amount of calls remaining."""
return self.max_requests - self._calls

def reset_calls(self) -> None:
"""Resets the calls to the max amount."""
self._calls = self.max_requests
self._reset_time = time.perf_counter() + 1

async def lock(self, delta: float) -> None:
def set_reset_time(self, delta: float) -> None:
"""
Lock the global lock for a given duration.
Sets the reset time to the current time + delta.

To be called if a 429 is received.
Args:
delta: The time to keep the lock acquired
delta: The time to wait before resetting the calls.
"""
await self._lock.acquire()
await asyncio.sleep(delta)
self._lock.release()
self._reset_time = time.perf_counter() + delta
self._calls = 0

async def wait(self) -> None:
"""Throttles calls to prevent hitting the global rate limit."""
async with self._lock:
if self._reset_time <= time.perf_counter():
self.reset_calls()
elif self._calls <= 0:
await asyncio.sleep(self._reset_time - time.perf_counter())
self.reset_calls()
self._calls -= 1


class BucketLock:
Expand Down Expand Up @@ -272,21 +285,17 @@ async def request(
for attempt in range(self._max_attempts):
async with lock:
try:
await self.global_lock.rate_limit()
# prevent us exceeding the global rate limit by throttling http requests

if cast(ClientSession, self.__session).closed:
if self.__session.closed:
await self.login(cast(str, self.token))

processed_data = self._process_payload(payload, files)
if isinstance(processed_data, FormData):
kwargs["data"] = processed_data # pyright: ignore
else:
kwargs["json"] = processed_data # pyright: ignore
await self.global_lock.wait()

async with cast(ClientSession, self.__session).request(
route.method, route.url, **kwargs
) as response:
async with self.__session.request(route.method, route.url, **kwargs) as response:
result = await response_decode(response)
self.ingest_ratelimit(route, response.headers, lock)

Expand All @@ -299,7 +308,7 @@ async def request(
logger.error(
f"Bot has exceeded global ratelimit, locking REST API for {result['retry_after']} seconds"
)
await self.global_lock.lock(float(result["retry_after"]))
self.global_lock.set_reset_time(float(result["retry_after"]))
continue
elif result.get("message") == "The resource is being rate limited.":
# resource ratelimit is reached
Expand Down Expand Up @@ -361,7 +370,7 @@ async def _raise_exception(self, response, route, result) -> None:

async def request_cdn(self, url, asset) -> bytes: # pyright: ignore [reportGeneralTypeIssues]
logger.debug(f"{asset} requests {url} from CDN")
async with cast(ClientSession, self.__session).get(url) as response:
async with self.__session.get(url) as response:
if response.status == 200:
return await response.read()
await self._raise_exception(response, asset, await response_decode(response))
Expand All @@ -377,7 +386,9 @@ async def login(self, token: str) -> dict[str, Any]:
The currently logged in bot's data

"""
self.__session = ClientSession(connector=self.connector)
self.__session = ClientSession(
connector=self.connector if self.connector else aiohttp.TCPConnector(limit=self.global_lock.max_requests),
)
self.token = token
try:
result = await self.request(Route("GET", "/users/@me"))
Expand Down Expand Up @@ -422,6 +433,6 @@ async def websocket_connect(self, url: str) -> ClientWebSocketResponse:
url: the url to connect to

"""
return await cast(ClientSession, self.__session).ws_connect(
return await self.__session.ws_connect(
url, timeout=30, max_msg_size=0, autoclose=False, headers={"User-Agent": self.user_agent}, compress=0
)