Skip to content
This repository has been archived by the owner on Mar 13, 2023. It is now read-only.

Commit

Permalink
fix: resolve global ratelimit being exceeded (#679)
Browse files Browse the repository at this point in the history
  • Loading branch information
LordOfPolls authored Oct 18, 2022
1 parent 1237435 commit 155325e
Showing 1 changed file with 40 additions and 29 deletions.
69 changes: 40 additions & 29 deletions naff/api/http/http_client.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""This file handles the interaction with discords http endpoints."""
import asyncio
import time
from typing import Any, cast
from urllib.parse import quote as _uriquote
from weakref import WeakValueDictionary

import aiohttp
import discord_typings
from aiohttp import BaseConnector, ClientSession, ClientWebSocketResponse, FormData
from multidict import CIMultiDictProxy

Expand Down Expand Up @@ -34,38 +36,49 @@
from naff.client.errors import DiscordError, Forbidden, GatewayNotFound, HTTPException, NotFound, LoginError
from naff.client.utils.input_utils import response_decode, OverriddenJson
from naff.client.utils.serializer import dict_filter
from naff.models import CooldownSystem
from naff.models.discord.file import UPLOADABLE_TYPE
from .route import Route
import discord_typings

__all__ = ("HTTPClient",)


class GlobalLock:
"""Manages the global ratelimit"""

def __init__(self) -> None:
self.cooldown_system: CooldownSystem = CooldownSystem(
45, 1
) # global rate-limit is 50 per second, conservatively we use 45
self._lock: asyncio.Lock = asyncio.Lock()
self._lock = asyncio.Lock()
self.max_requests = 45
self._calls = 0
self._reset_time = 0

async def rate_limit(self) -> None:
async with self._lock:
while not self.cooldown_system.acquire_token():
await asyncio.sleep(self.cooldown_system.get_cooldown_time())
@property
def calls_remaining(self) -> int:
"""Returns the amount of calls remaining."""
return self.max_requests - self._calls

def reset_calls(self) -> None:
"""Resets the calls to the max amount."""
self._calls = self.max_requests
self._reset_time = time.perf_counter() + 1

async def lock(self, delta: float) -> None:
def set_reset_time(self, delta: float) -> None:
"""
Lock the global lock for a given duration.
Sets the reset time to the current time + delta.
To be called if a 429 is received.
Args:
delta: The time to keep the lock acquired
delta: The time to wait before resetting the calls.
"""
await self._lock.acquire()
await asyncio.sleep(delta)
self._lock.release()
self._reset_time = time.perf_counter() + delta
self._calls = 0

async def wait(self) -> None:
"""Throttles calls to prevent hitting the global rate limit."""
async with self._lock:
if self._reset_time <= time.perf_counter():
self.reset_calls()
elif self._calls <= 0:
await asyncio.sleep(self._reset_time - time.perf_counter())
self.reset_calls()
self._calls -= 1


class BucketLock:
Expand Down Expand Up @@ -272,21 +285,17 @@ async def request(
for attempt in range(self._max_attempts):
async with lock:
try:
await self.global_lock.rate_limit()
# prevent us exceeding the global rate limit by throttling http requests

if cast(ClientSession, self.__session).closed:
if self.__session.closed:
await self.login(cast(str, self.token))

processed_data = self._process_payload(payload, files)
if isinstance(processed_data, FormData):
kwargs["data"] = processed_data # pyright: ignore
else:
kwargs["json"] = processed_data # pyright: ignore
await self.global_lock.wait()

async with cast(ClientSession, self.__session).request(
route.method, route.url, **kwargs
) as response:
async with self.__session.request(route.method, route.url, **kwargs) as response:
result = await response_decode(response)
self.ingest_ratelimit(route, response.headers, lock)

Expand All @@ -299,7 +308,7 @@ async def request(
logger.error(
f"Bot has exceeded global ratelimit, locking REST API for {result['retry_after']} seconds"
)
await self.global_lock.lock(float(result["retry_after"]))
self.global_lock.set_reset_time(float(result["retry_after"]))
continue
elif result.get("message") == "The resource is being rate limited.":
# resource ratelimit is reached
Expand Down Expand Up @@ -361,7 +370,7 @@ async def _raise_exception(self, response, route, result) -> None:

async def request_cdn(self, url, asset) -> bytes: # pyright: ignore [reportGeneralTypeIssues]
logger.debug(f"{asset} requests {url} from CDN")
async with cast(ClientSession, self.__session).get(url) as response:
async with self.__session.get(url) as response:
if response.status == 200:
return await response.read()
await self._raise_exception(response, asset, await response_decode(response))
Expand All @@ -377,7 +386,9 @@ async def login(self, token: str) -> dict[str, Any]:
The currently logged in bot's data
"""
self.__session = ClientSession(connector=self.connector)
self.__session = ClientSession(
connector=self.connector if self.connector else aiohttp.TCPConnector(limit=self.global_lock.max_requests),
)
self.token = token
try:
result = await self.request(Route("GET", "/users/@me"))
Expand Down Expand Up @@ -422,6 +433,6 @@ async def websocket_connect(self, url: str) -> ClientWebSocketResponse:
url: the url to connect to
"""
return await cast(ClientSession, self.__session).ws_connect(
return await self.__session.ws_connect(
url, timeout=30, max_msg_size=0, autoclose=False, headers={"User-Agent": self.user_agent}, compress=0
)

0 comments on commit 155325e

Please sign in to comment.