diff --git a/src/solana/exceptions.py b/src/solana/exceptions.py new file mode 100644 index 00000000..662024f2 --- /dev/null +++ b/src/solana/exceptions.py @@ -0,0 +1,54 @@ +"""Exceptions native to solana-py.""" +from typing import Callable, Any + + +class SolanaExceptionBase(Exception): + """Base class for Solana-py exceptions.""" + + def __init__(self, exc: Exception, func: Callable[[Any], Any], *args: Any, **kwargs: Any) -> None: + """Init.""" + super().__init__() + self.error_msg = self._build_error_message(exc, func, *args, **kwargs) + + @staticmethod + def _build_error_message(exc: Exception, func: Callable[[Any], Any], *args: Any, **kwargs: Any) -> str: + return f"{type(exc)} raised in {func} invokation" + + +class SolanaRpcException(SolanaExceptionBase): + """Class for Solana-py RPC exceptions.""" + + @staticmethod + def _build_error_message(exc: Exception, func: Callable[[Any], Any], *args: Any, **kwargs: Any) -> str: + rpc_method = args[1] + return f'{type(exc)} raised in "{rpc_method}" endpoint request' + + +def handle_exceptions(internal_exception_cls, *exception_types_caught): + """Decorator for handling non-async exception.""" + + def func_decorator(func): + def argument_decorator(*args, **kwargs): + try: + return func(*args, **kwargs) + except exception_types_caught as exc: + raise internal_exception_cls(exc, func, *args, **kwargs) + + return argument_decorator + + return func_decorator + + +def handle_async_exceptions(internal_exception_cls, *exception_types_caught): + """Decorator for handling async exception.""" + + def func_decorator(func): + async def argument_decorator(*args, **kwargs): + try: + return await func(*args, **kwargs) + except exception_types_caught as exc: + raise internal_exception_cls(exc, func, *args, **kwargs) + + return argument_decorator + + return func_decorator diff --git a/src/solana/rpc/providers/async_http.py b/src/solana/rpc/providers/async_http.py index 8959d790..aa0631c1 100644 --- a/src/solana/rpc/providers/async_http.py +++ b/src/solana/rpc/providers/async_http.py @@ -6,6 +6,7 @@ from ..types import RPCMethod, RPCResponse from .async_base import AsyncBaseProvider from .core import _HTTPProviderCore, DEFAULT_TIMEOUT +from ...exceptions import handle_async_exceptions, SolanaRpcException class AsyncHTTPProvider(AsyncBaseProvider, _HTTPProviderCore): @@ -20,6 +21,7 @@ def __str__(self) -> str: """String definition for HTTPProvider.""" return f"Async HTTP RPC connection {self.endpoint_uri}" + @handle_async_exceptions(SolanaRpcException, Exception) async def make_request(self, method: RPCMethod, *params: Any) -> RPCResponse: """Make an async HTTP request to an http rpc endpoint.""" request_kwargs = self._before_request(method=method, params=params, is_async=True) diff --git a/src/solana/rpc/providers/http.py b/src/solana/rpc/providers/http.py index a7eb6fbe..a4e388ff 100644 --- a/src/solana/rpc/providers/http.py +++ b/src/solana/rpc/providers/http.py @@ -6,6 +6,7 @@ from ..types import RPCMethod, RPCResponse from .base import BaseProvider from .core import _HTTPProviderCore +from ...exceptions import handle_exceptions, SolanaRpcException class HTTPProvider(BaseProvider, _HTTPProviderCore): @@ -15,6 +16,7 @@ def __str__(self) -> str: """String definition for HTTPProvider.""" return f"HTTP RPC connection {self.endpoint_uri}" + @handle_exceptions(SolanaRpcException, requests.exceptions.RequestException) def make_request(self, method: RPCMethod, *params: Any) -> RPCResponse: """Make an HTTP request to an http rpc endpoint.""" request_kwargs = self._before_request(method=method, params=params, is_async=False) diff --git a/tests/conftest.py b/tests/conftest.py index 13222534..de820817 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -118,6 +118,20 @@ def freeze_authority() -> Keypair: return Keypair.from_seed(bytes([6] * PublicKey.LENGTH)) +@pytest.fixture(scope="session") +def unit_test_http_client() -> Client: + """Client to be used in unit tests.""" + client = Client(commitment=Processed) + return client + + +@pytest.fixture(scope="session") +def unit_test_http_client_async() -> AsyncClient: + """Async client to be used in unit tests.""" + client = AsyncClient(commitment=Processed) + return client + + @pytest.mark.integration @pytest.fixture(scope="session") def test_http_client(docker_services) -> Client: diff --git a/tests/unit/test_async_client.py b/tests/unit/test_async_client.py new file mode 100644 index 00000000..71c355ce --- /dev/null +++ b/tests/unit/test_async_client.py @@ -0,0 +1,21 @@ +from unittest.mock import patch +from requests.exceptions import ReadTimeout + +import pytest + +from solana.exceptions import SolanaRpcException + + +@pytest.mark.asyncio +async def test_async_client_http_exception(unit_test_http_client_async): + """Test AsyncClient raises native Solana-py exceptions.""" + + with patch("httpx.AsyncClient.post") as post_mock: + post_mock.side_effect = ReadTimeout() + with pytest.raises(SolanaRpcException) as exc_info: + await unit_test_http_client_async.get_epoch_info() + assert exc_info.type == SolanaRpcException + assert ( + exc_info.value.error_msg + == " raised in \"getEpochInfo\" endpoint request" + ) diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py new file mode 100644 index 00000000..faba7a4d --- /dev/null +++ b/tests/unit/test_client.py @@ -0,0 +1,19 @@ +from unittest.mock import patch +from requests.exceptions import ReadTimeout +import pytest + +from solana.exceptions import SolanaRpcException + + +def test_client_http_exception(unit_test_http_client): + """Test AsyncClient raises native Solana-py exceptions.""" + + with patch("requests.post") as post_mock: + post_mock.side_effect = ReadTimeout() + with pytest.raises(SolanaRpcException) as exc_info: + unit_test_http_client.get_epoch_info() + assert exc_info.type == SolanaRpcException + assert ( + exc_info.value.error_msg + == " raised in \"getEpochInfo\" endpoint request" + )