From 8375261da3b84d6fece97263c7bea40ad2a6cfcf Mon Sep 17 00:00:00 2001 From: Jonas L Date: Tue, 2 Jul 2024 15:15:47 +0200 Subject: [PATCH] feat: add `trace_id` to API exceptions (#404) - Do not use helper functions to raise exceptions, - Add the request trace ID to the API exceptions, - Refactor the code. --- hcloud/_client.py | 65 ++++++++++++++++++++------------------- hcloud/_exceptions.py | 18 +++++++++-- tests/unit/test_client.py | 27 +++++++++++++++- 3 files changed, 76 insertions(+), 34 deletions(-) diff --git a/hcloud/_client.py b/hcloud/_client.py index 84c4735..81aabe2 100644 --- a/hcloud/_client.py +++ b/hcloud/_client.py @@ -1,7 +1,7 @@ from __future__ import annotations import time -from typing import NoReturn, Protocol +from typing import Protocol import requests @@ -190,20 +190,6 @@ def _get_headers(self) -> dict: } return headers - def _raise_exception_from_response(self, response: requests.Response) -> NoReturn: - raise APIException( - code=response.status_code, - message=response.reason, - details={"content": response.content}, - ) - - def _raise_exception_from_content(self, content: dict) -> NoReturn: - raise APIException( - code=content["error"]["code"], - message=content["error"]["message"], - details=content["error"]["details"], - ) - def request( # type: ignore[no-untyped-def] self, method: str, @@ -229,23 +215,40 @@ def request( # type: ignore[no-untyped-def] **kwargs, ) - content = {} + trace_id = response.headers.get("X-Correlation-Id") + payload = {} try: if len(response.content) > 0: - content = response.json() - except (TypeError, ValueError): - self._raise_exception_from_response(response) + payload = response.json() + except (TypeError, ValueError) as exc: + raise APIException( + code=response.status_code, + message=response.reason, + details={"content": response.content}, + trace_id=trace_id, + ) from exc if not response.ok: - if content: - assert isinstance(content, dict) - if content["error"]["code"] == "rate_limit_exceeded" and _tries < 5: - time.sleep(_tries * self._retry_wait_time) - _tries = _tries + 1 - return self.request(method, url, _tries=_tries, **kwargs) - - self._raise_exception_from_content(content) - else: - self._raise_exception_from_response(response) - - return content + if not payload or "error" not in payload: + raise APIException( + code=response.status_code, + message=response.reason, + details={"content": response.content}, + trace_id=trace_id, + ) + + error: dict = payload["error"] + + if error["code"] == "rate_limit_exceeded" and _tries < 5: + time.sleep(_tries * self._retry_wait_time) + _tries = _tries + 1 + return self.request(method, url, _tries=_tries, **kwargs) + + raise APIException( + code=error["code"], + message=error["message"], + details=error["details"], + trace_id=trace_id, + ) + + return payload diff --git a/hcloud/_exceptions.py b/hcloud/_exceptions.py index 877083f..cb6e60f 100644 --- a/hcloud/_exceptions.py +++ b/hcloud/_exceptions.py @@ -10,8 +10,22 @@ class HCloudException(Exception): class APIException(HCloudException): """There was an error while performing an API Request""" - def __init__(self, code: int | str, message: str | None, details: Any): - super().__init__(code if message is None and isinstance(code, str) else message) + def __init__( + self, + code: int | str, + message: str, + details: Any, + *, + trace_id: str | None = None, + ): + extras = [str(code)] + if trace_id is not None: + extras.append(trace_id) + + error = f"{message} ({', '.join(extras)})" + + super().__init__(error) self.code = code self.message = message self.details = details + self.trace_id = trace_id diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 4a1da1a..b70c142 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -102,6 +102,31 @@ def test_request_fails(self, client, fail_response): assert error.message == "invalid input in field 'broken_field': is too long" assert error.details["fields"][0]["name"] == "broken_field" + def test_request_fails_trace_id(self, client, response): + response.headers["X-Correlation-Id"] = "67ed842dc8bc8673" + response.status_code = 409 + response._content = json.dumps( + { + "error": { + "code": "conflict", + "message": "some conflict", + "details": None, + } + } + ).encode("utf-8") + + client._requests_session.request.return_value = response + with pytest.raises(APIException) as exception_info: + client.request( + "POST", "http://url.com", params={"argument": "value"}, timeout=2 + ) + error = exception_info.value + assert error.code == "conflict" + assert error.message == "some conflict" + assert error.details is None + assert error.trace_id == "67ed842dc8bc8673" + assert str(error) == "some conflict (conflict, 67ed842dc8bc8673)" + def test_request_500(self, client, fail_response): fail_response.status_code = 500 fail_response.reason = "Internal Server Error" @@ -153,7 +178,7 @@ def test_request_500_empty_content(self, client, fail_response): assert error.code == 500 assert error.message == "Internal Server Error" assert error.details["content"] == "" - assert str(error) == "Internal Server Error" + assert str(error) == "Internal Server Error (500)" def test_request_limit(self, client, rate_limit_response): client._retry_wait_time = 0