Skip to content

Commit

Permalink
[ENH] review suggestions
Browse files Browse the repository at this point in the history
Signed-off-by: jmcdonnell <jmcdonnell@fieldbox.ai>
  • Loading branch information
jmcdonnell committed May 11, 2023
1 parent 5c476be commit c4ddcb7
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 47 deletions.
104 changes: 64 additions & 40 deletions kedro-datasets/kedro_datasets/api/api_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,16 +60,15 @@ class APIDataSet(AbstractDataSet[None, requests.Response]):
``APIDataSet`` can also be used to save output on a remote server using
HTTP(S) methods.
>>> import pandas as pd
>>> example_table = pd.DataFrame({"col1":["val1", "val2"], "col2":["val3", "val4"]}
>>> example_table = '{"col1":["val1", "val2"], "col2":["val3", "val4"]}'
Here we initialise our APIDataSet with the correct parameters to make requests
towards the configured remote server.
>>> data_set = APIDataSet(
method = "POST",
url = "url_of_remote_server",
save_args = {"method":"POST",
"chunk_size":1}
save_args = {"chunk_size":1}
)
On initialisation, we can specify all the necessary parameters in the save args
dictionary. Saving requires ``method`` to be POST or PUT; the default, GET, is only for loading.
Expand All @@ -89,7 +88,6 @@ class APIDataSet(AbstractDataSet[None, requests.Response]):
"""

DEFAULT_SAVE_ARGS = {
"method": "POST",
"params": None,
"headers": None,
"auth": None,
Expand All @@ -105,44 +103,31 @@ def __init__(
method: str = "GET",
load_args: Dict[str, Any] = None,
credentials: Union[Tuple[str, str], List[str], AuthBase] = None,
save_args: Dict[str, Any] = None,
) -> None:
"""Creates a new instance of ``APIDataSet`` to fetch data from an API endpoint.
Args:
url: The API URL endpoint.
method: The Method of the request, GET, POST, PUT
method: The Method of the request. GET, POST, PUT are the only supported
methods
load_args: Additional parameters to be fed to requests.request.
https://requests.readthedocs.io/en/latest/api/#requests.request
credentials: Allows specifying secrets in credentials.yml.
Expected format is ``('login', 'password')`` if given as a tuple or list.
An ``AuthBase`` instance can be provided for more complex cases.
requests
https://requests.readthedocs.io/en/latest/user/quickstart/#more-complicated-post-requests
params: The url parameters of the API.
https://requests.readthedocs.io/en/latest/user/quickstart/#passing-parameters-in-urls
headers: The HTTP headers.
https://requests.readthedocs.io/en/latest/user/quickstart/#custom-headers
auth: Anything ``requests`` accepts. Normally it's either ``('login',
'password')``,
or ``AuthBase``, ``HTTPBasicAuth`` instance for more complex cases. Any
iterable will be cast to a tuple.
json: The request payload, used for POST, PUT, etc requests, passed in
to the json kwarg in the requests object.
https://requests.readthedocs.io/en/latest/user/quickstart/#more-complicated-post-requests
timeout: The wait time in seconds for a response, defaults to 1 minute.
https://requests.readthedocs.io/en/latest/user/quickstart/#timeouts
credentials: same as ``auth``. Allows specifying ``auth`` secrets in
credentials.yml.
Expected format is ``('login', 'password')`` if given as a tuple or
list. An ``AuthBase`` instance can be provided for more complex cases.
save_args: Options for saving data on server. Includes all parameters used
during load method. Adds an optional parameter, ``chunk_size`` which determines the
size of the package sent at each request.
during load method. Adds an optional parameter, ``chunk_size`` which
determines the size of the package sent at each request.
Raises:
ValueError: if both ``auth`` in ``load_args`` and ``credentials`` are specified.
ValueError: if both ``auth`` in ``load_args`` and ``credentials`` are
specified.
"""
super().__init__()

self._load_args = load_args or {}
self._load_args_auth = self._load_args.pop("auth", None)
# GET method means load
if method == "GET":
self._params = load_args or {}

# PUT, POST means save
elif method in ["PUT", "POST"]:
Expand All @@ -158,19 +143,19 @@ def __init__(
if credentials is not None and self._param_auth is not None:
raise ValueError("Cannot specify both auth and credentials.")

self._auth = credentials or self._load_args_auth
self._auth = credentials or self._param_auth

if "cert" in self._load_args:
self._load_args["cert"] = self._convert_type(self._load_args["cert"])
if "cert" in self._params:
self._params["cert"] = self._convert_type(self._params["cert"])

if "timeout" in self._load_args:
self._load_args["timeout"] = self._convert_type(self._load_args["timeout"])
if "timeout" in self._params:
self._params["timeout"] = self._convert_type(self._params["timeout"])

self._request_args: Dict[str, Any] = {
"url": url,
"method": method,
"auth": self._convert_type(self._auth),
**self._load_args,
**self._params,
}

@staticmethod
Expand Down Expand Up @@ -202,11 +187,50 @@ def _execute_request(self, session: Session) -> requests.Response:
return response

def _load(self) -> requests.Response:
return self._execute_request()
if self._request_args["method"] == "GET":
with sessions.Session() as session:
return self._execute_request(session)

def _save(self, data: None) -> NoReturn:
raise DataSetError(f"{self.__class__.__name__} is a read only data set type")
raise DataSetError("Only GET method is supported for load")

def _execute_save_with_chunks(
    self,
    json_data: List[Dict[str, Any]],
) -> requests.Response:
    """Send ``json_data`` in consecutive slices of ``self._chunk_size`` items.

    Each slice is sent with its own call to ``_execute_save_request``.

    Args:
        json_data: The list of items to send.

    Returns:
        The response to the last request that was sent.
    """
    chunk_size = self._chunk_size
    # Ceiling division. The previous ``len // chunk_size + 1`` issued one
    # extra request with an empty payload whenever the length was an exact
    # multiple of the chunk size.
    n_chunks = (len(json_data) + chunk_size - 1) // chunk_size
    # An empty list still produces a single request (as it did before), so
    # that there is always a response to return.
    n_chunks = max(n_chunks, 1)

    for i in range(n_chunks):
        send_data = json_data[i * chunk_size : (i + 1) * chunk_size]
        response = self._execute_save_request(json_data=send_data)

    return response

def _execute_save_request(self, json_data: Any) -> requests.Response:
    """Send ``json_data`` as the JSON payload of a single request.

    The payload is stored under the ``"json"`` key of ``self._request_args``
    and the request is issued with those arguments.

    Args:
        json_data: The payload. If ``json_.loads`` accepts it, the decoded
            value is sent; if ``json_.loads`` raises ``TypeError``, the
            object is handed to ``requests`` unchanged.

    Returns:
        The response returned by ``requests.request``.

    Raises:
        DataSetError: If the response carries an HTTP error status, or if an
            ``OSError`` is raised while sending the request.
    """
    # NOTE(review): ``json_`` is presumably the standard ``json`` module
    # imported under an alias -- confirm against the file's imports.
    try:
        # Previously the decoded value was discarded, so a payload that was
        # already a JSON string was never attached to the request.
        self._request_args["json"] = json_.loads(json_data)
    except TypeError:
        # Pass the object itself: the ``json`` argument of ``requests`` does
        # the serialisation, so dumping it here first would encode it twice.
        self._request_args["json"] = json_data
    try:
        response = requests.request(**self._request_args)
        response.raise_for_status()
    except requests.exceptions.HTTPError as exc:
        raise DataSetError("Failed to send data", exc) from exc

    except OSError as exc:
        raise DataSetError("Failed to connect to the remote server") from exc
    return response

def _save(self, data: Any) -> requests.Response:
    """Send ``data`` to the server using the configured request method.

    Lists are sent through ``_execute_save_with_chunks``; any other object is
    sent through ``_execute_save_request``.

    Args:
        data: The payload to send.

    Returns:
        The response returned by the helper that sent the data.

    Raises:
        DataSetError: If the configured method is neither PUT nor POST.
    """
    # Guard clause: reject every method that cannot be used for saving.
    if self._request_args["method"] not in ["PUT", "POST"]:
        raise DataSetError("Use PUT or POST methods for save")

    send = (
        self._execute_save_with_chunks
        if isinstance(data, list)
        else self._execute_save_request
    )
    return send(json_data=data)

def _exists(self) -> bool:
response = self._execute_request()
with sessions.Session() as session:
response = self._execute_request(session)
return response.ok
29 changes: 22 additions & 7 deletions kedro-datasets/tests/api/test_api_dataset.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
# pylint: disable=no-member
import base64
import json
import socket

import pytest
import requests
import requests_mock
from kedro.io.core import DataSetError
from requests.auth import HTTPBasicAuth

Expand All @@ -15,6 +17,7 @@
TEST_URL = "http://example.com/api/test"
TEST_TEXT_RESPONSE_DATA = "This is a response."
TEST_JSON_REQUEST_DATA = [{"key": "value"}]
TEST_JSON_RESPONSE_DATA = [{"key": "value"}]

TEST_PARAMS = {"param": "value"}
TEST_URL_WITH_PARAMS = TEST_URL + "?param=value"
Expand All @@ -39,8 +42,14 @@ def test_request_method(self, requests_mock, method):

requests_mock.register_uri(method, TEST_URL, text=TEST_TEXT_RESPONSE_DATA)

response = api_data_set.load()
assert response.text == TEST_TEXT_RESPONSE_DATA
if method == "GET":
response = api_data_set.load()
assert response.text == TEST_TEXT_RESPONSE_DATA
else:
with pytest.raises(
DataSetError, match="Only GET method is supported for load"
):
api_data_set.load()

@pytest.mark.parametrize(
"parameters_in, url_postfix",
Expand Down Expand Up @@ -263,9 +272,17 @@ def test_socket_error(self, requests_mock):
with pytest.raises(DataSetError, match="Failed to connect"):
api_data_set.load()

def test_read_only_mode(self):
@pytest.fixture
def requests_mocker(self):
with requests_mock.Mocker() as mock:
yield mock

@pytest.mark.parametrize("method", POSSIBLE_METHODS)
def test_successful_save(self, requests_mocker, method):
"""
Saving is disabled on the data set.
When we want to save some data on a server
Given an APIDataSet class
Then check we get a response
"""
if method in ["PUT", "POST"]:
api_data_set = APIDataSet(
Expand All @@ -288,9 +305,7 @@ def test_read_only_mode(self):
method=method,
save_args={"params": TEST_PARAMS, "headers": TEST_HEADERS},
)
with pytest.raises(
DataSetError, match="Use PUT, POST or DELETE methods for save"
):
with pytest.raises(DataSetError, match="Use PUT or POST methods for save"):
api_data_set._save(TEST_SAVE_DATA)
else:
with pytest.raises(
Expand Down

0 comments on commit c4ddcb7

Please sign in to comment.