Skip to content

Commit

Permalink
feat: add exponential and constant backoff function
Browse files Browse the repository at this point in the history
  • Loading branch information
jooola committed Jul 19, 2024
1 parent 49d5d4f commit 638da80
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 5 deletions.
6 changes: 5 additions & 1 deletion hcloud/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from __future__ import annotations

from ._client import Client as Client # noqa pylint: disable=C0414
from ._client import ( # noqa pylint: disable=C0414
Client as Client,
constant_backoff_function as constant_backoff_function,
exponential_backoff_function as exponential_backoff_function,
)
from ._exceptions import ( # noqa pylint: disable=C0414
APIException as APIException,
HCloudException as HCloudException,
Expand Down
48 changes: 45 additions & 3 deletions hcloud/_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import time
from random import uniform
from typing import Protocol

import requests
Expand All @@ -26,7 +27,7 @@
from .volumes import VolumesClient


class PollIntervalFunction(Protocol):
class BackoffFunction(Protocol):
def __call__(self, retries: int) -> float:
"""
Return a interval in seconds to wait between each API call.
Expand All @@ -35,6 +36,47 @@ def __call__(self, retries: int) -> float:
"""


def constant_backoff_function(interval: float) -> BackoffFunction:
"""
Return a backoff function, implementing a constant backoff.
:param interval: Constant interval to return.
"""

# pylint: disable=unused-argument
def func(retries: int) -> float:
return interval

return func


def exponential_backoff_function(
*,
base: float,
multiplier: int,
cap: float,
jitter: bool = False,
) -> BackoffFunction:
"""
Return a backoff function, implementing a truncated exponential backoff with
optional full jitter.
:param base: Base for the exponential backoff algorithm.
:param multiplier: Multiplier for the exponential backoff algorithm.
:param cap: Value at which the interval is truncated.
:param jitter: Whether to add jitter.
"""

def func(retries: int) -> float:
interval = base * multiplier**retries # Exponential backoff
interval = min(cap, interval) # Cap backoff
if jitter:
interval = uniform(base, interval) # Add jitter
return interval

return func


class Client:
"""Base Client for accessing the Hetzner Cloud API"""

Expand All @@ -48,7 +90,7 @@ def __init__(
api_endpoint: str = "https://api.hetzner.cloud/v1",
application_name: str | None = None,
application_version: str | None = None,
poll_interval: int | float | PollIntervalFunction = 1.0,
poll_interval: int | float | BackoffFunction = 1.0,
poll_max_retries: int = 120,
timeout: float | tuple[float, float] | None = None,
):
Expand All @@ -73,7 +115,7 @@ def __init__(
self._requests_timeout = timeout

if isinstance(poll_interval, (int, float)):
self._poll_interval_func = lambda _: poll_interval # Constant poll interval
self._poll_interval_func = constant_backoff_function(poll_interval)
else:
self._poll_interval_func = poll_interval
self._poll_max_retries = poll_max_retries
Expand Down
28 changes: 27 additions & 1 deletion tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@
import pytest
import requests

from hcloud import APIException, Client
from hcloud import (
APIException,
Client,
constant_backoff_function,
exponential_backoff_function,
)


class TestHetznerClient:
Expand Down Expand Up @@ -202,3 +207,24 @@ def test_request_limit_then_success(self, client, rate_limit_response):
"POST", "http://url.com", params={"argument": "value"}, timeout=2
)
assert client._requests_session.request.call_count == 2


def test_constant_backoff_function():
backoff = constant_backoff_function(interval=1.0)
max_retries = 5

for i in range(max_retries):
assert backoff(i) == 1.0


def test_exponential_backoff_function():
backoff = exponential_backoff_function(
base=1.0,
multiplier=2,
cap=60.0,
)
max_retries = 5

results = [backoff(i) for i in range(max_retries)]
assert sum(results) == 31.0
assert results == [1.0, 2.0, 4.0, 8.0, 16.0]

0 comments on commit 638da80

Please sign in to comment.