From 90187a16a94cc26b25e077f7600f5196da922635 Mon Sep 17 00:00:00 2001 From: Ben Smith Date: Fri, 8 Sep 2023 12:45:14 +0200 Subject: [PATCH] add download_csv functionality --- dune_client/api/execution.py | 2 - dune_client/api/extensions.py | 32 ++++++++++------ dune_client/client_async.py | 12 +----- dune_client/models.py | 2 +- dune_client/query.py | 16 ++++++++ tests/e2e/test_client.py | 71 +++++++++++++++++++++++++++++------ tests/unit/test_query.py | 28 +++++++++++++- 7 files changed, 124 insertions(+), 39 deletions(-) diff --git a/dune_client/api/execution.py b/dune_client/api/execution.py index 78875d6..e06924e 100644 --- a/dune_client/api/execution.py +++ b/dune_client/api/execution.py @@ -83,8 +83,6 @@ def get_execution_results_csv(self, job_id: str) -> ExecutionResultCSV: if you need metadata information use get_results() or get_status() """ route = f"/execution/{job_id}/results/csv" - url = self._route_url(f"/execution/{job_id}/results/csv") - self.logger.debug(f"GET CSV received input url={url}") response = self._get(route=route, raw=True) response.raise_for_status() return ExecutionResultCSV(data=BytesIO(response.content)) diff --git a/dune_client/api/extensions.py b/dune_client/api/extensions.py index 0a41677..fa0ca58 100644 --- a/dune_client/api/extensions.py +++ b/dune_client/api/extensions.py @@ -3,6 +3,7 @@ """ import time +from io import BytesIO from typing import Union, Optional, Any from deprecated import deprecated @@ -15,7 +16,7 @@ QueryFailed, ExecutionResultCSV, ) -from dune_client.query import QueryBase +from dune_client.query import QueryBase, parse_query_object_or_id class ExtendedAPI(ExecutionAPI): @@ -78,19 +79,11 @@ def get_latest_result(self, query: Union[QueryBase, str, int]) -> ResultsRespons GET the latest results for a query_id without re-executing the query (doesn't use execution credits) - :param query: :class:`Query` object OR query id as string | int + :param query: :class:`Query` object OR query id as string or int - https://dune.com/docs/api/api-reference/latest_results/ + https://dune.com/docs/api/api-reference/get-results/latest-results """ - if isinstance(query, QueryBase): - params = { - f"params.{p.key}": p.to_dict()["value"] for p in query.parameters() - } - query_id = query.query_id - else: - params = None - query_id = int(query) - + params, query_id = parse_query_object_or_id(query) response_json = self._get( route=f"/query/{query_id}/results", params=params, @@ -100,6 +93,21 @@ def get_latest_result(self, query: Union[QueryBase, str, int]) -> ResultsRespons except KeyError as err: raise DuneError(response_json, "ResultsResponse", err) from err + def download_csv(self, query: Union[QueryBase, str, int]) -> ExecutionResultCSV: + """ + Almost like an alias for `get_latest_result` but for the csv endpoint. + https://dune.com/docs/api/api-reference/get-results/latest-results + """ + params, query_id = parse_query_object_or_id(query) + response = self._get( + route=f"/query/{query_id}/results/csv", params=params, raw=True + ) + response.raise_for_status() + return ExecutionResultCSV(data=BytesIO(response.content)) + + ############################ + # Plus Subscription Features + ############################ def upload_csv(self, table_name: str, data: str, description: str = "") -> bool: """ https://dune.com/docs/api/api-reference/upload-data/?h=data+upload#endpoint diff --git a/dune_client/client_async.py b/dune_client/client_async.py index f3d7d42..d9e9c6c 100644 --- a/dune_client/client_async.py +++ b/dune_client/client_async.py @@ -28,7 +28,7 @@ ExecutionState, ) -from dune_client.query import QueryBase +from dune_client.query import QueryBase, parse_query_object_or_id # pylint: disable=duplicate-code @@ -181,15 +181,7 @@ async def get_latest_result( https://dune.com/docs/api/api-reference/latest_results/ """ - if isinstance(query, QueryBase): - params = { - f"params.{p.key}": p.to_dict()["value"] for p in query.parameters() - } - query_id = query.query_id - else: - params = None - query_id = int(query) - + params, query_id = parse_query_object_or_id(query) response_json = await self._get( route=f"/query/{query_id}/results", params=params, diff --git a/dune_client/models.py b/dune_client/models.py index 5482a55..5721ead 100644 --- a/dune_client/models.py +++ b/dune_client/models.py @@ -183,7 +183,7 @@ class ExecutionResultCSV: Representation of a raw `result` in CSV format this payload can be passed directly to csv.reader(data) or - pandas.from_csv(data) + pandas.read_csv(data) """ data: BytesIO # includes all CSV rows, including the header row. diff --git a/dune_client/query.py b/dune_client/query.py index 1a0da0a..c961cc1 100644 --- a/dune_client/query.py +++ b/dune_client/query.py @@ -9,6 +9,22 @@ from dune_client.types import QueryParameter +def parse_query_object_or_id( + query: Union[QueryBase, str, int], +) -> tuple[dict[str, str] | None, int]: + """ + Users are allowed to pass QueryBase or ID into some functions. + This method handles both scenarios, returning a pair of the form (params, query_id) + """ + if isinstance(query, QueryBase): + params = {f"params.{p.key}": p.to_dict()["value"] for p in query.parameters()} + query_id = query.query_id + else: + params = None + query_id = int(query) + return params, query_id + + @dataclass class QueryBase: """Basic data structure constituting a Dune Analytics Query.""" diff --git a/tests/e2e/test_client.py b/tests/e2e/test_client.py index a862145..2c9ed70 100644 --- a/tests/e2e/test_client.py +++ b/tests/e2e/test_client.py @@ -4,6 +4,7 @@ import unittest import dotenv +import pandas from dune_client.models import ( ExecutionState, @@ -33,6 +34,18 @@ def setUp(self) -> None: ) self.valid_api_key = os.environ["DUNE_API_KEY"] + def copy_query_and_change_parameters(self) -> QueryBase: + new_query = copy.copy(self.query) + new_query.params = [ + # Using all different values for parameters. + QueryParameter.text_type(name="TextField", value="different word"), + QueryParameter.number_type(name="NumberField", value=22), + QueryParameter.date_type(name="DateField", value="1991-01-01 00:00:00"), + QueryParameter.enum_type(name="ListField", value="Option 2"), + ] + self.assertNotEqual(self.query.parameters(), new_query.parameters()) + return new_query + def test_from_env_constructor(self): try: DuneClient.from_env() @@ -64,19 +77,9 @@ def test_run_query_dataframe(self): self.assertGreater(len(pd), 0) def test_parameters_recognized(self): - query = copy.copy(self.query) - new_params = [ - # Using all different values for parameters. - QueryParameter.text_type(name="TextField", value="different word"), - QueryParameter.number_type(name="NumberField", value=22), - QueryParameter.date_type(name="DateField", value="1991-01-01 00:00:00"), - QueryParameter.enum_type(name="ListField", value="Option 2"), - ] - query.params = new_params - self.assertEqual(query.parameters(), new_params) - + new_query = self.copy_query_and_change_parameters() dune = DuneClient(self.valid_api_key) - results = dune.run_query(query) + results = dune.run_query(new_query) self.assertEqual( results.get_rows(), [ @@ -194,6 +197,50 @@ def test_upload_csv_success(self): True, ) + def test_download_csv_success_by_id(self): + client = DuneClient(self.valid_api_key) + new_query = self.copy_query_and_change_parameters() + # Run query with new parameters + client.run_query(new_query) + # Download CSV by query_id + result_csv = client.download_csv(self.query.query_id) + # Expect that the csv returns the latest execution results (i.e. those that were just run) + self.assertEqual( + pandas.read_csv(result_csv.data).to_dict(orient="records"), + [ + { + "text_field": "different word", + "number_field": 22, + "date_field": "1991-01-01 00:00:00.000", + "list_field": "Option 2", + } + ], + ) + + def test_download_csv_success_with_params(self): + client = DuneClient(self.valid_api_key) + # Download CSV with query and given parameters. + result_csv = client.download_csv(self.query) + # Expect the result to be relative to values of given parameters. + ################################################################# + # Note that we could compare results with + # ",".join([p.value for p in self.query.parameters()]) + "\n" + # but there seems to be a discrepancy with the date string values. + # Specifically 1991-01-01 00:00:00.000 + # vs 1991-01-01 00:00:00 + ################################################################# + self.assertEqual( + pandas.read_csv(result_csv.data).to_dict(orient="records"), + [ + { + "date_field": "2022-05-04 00:00:00.000", + "list_field": "Option 1", + "number_field": 3.1415926535, + "text_field": "Plain Text", + } + ], + ) + @unittest.skip("This is an enterprise only endpoint that can no longer be tested.") class TestCRUDOps(unittest.TestCase): diff --git a/tests/unit/test_query.py b/tests/unit/test_query.py index 2c73f81..a803d93 100644 --- a/tests/unit/test_query.py +++ b/tests/unit/test_query.py @@ -1,11 +1,11 @@ import unittest from datetime import datetime -from dune_client.query import QueryBase +from dune_client.query import QueryBase, parse_query_object_or_id from dune_client.types import QueryParameter -class TestQueryMonitor(unittest.TestCase): +class TestQueryBase(unittest.TestCase): def setUp(self) -> None: self.date = datetime(year=1985, month=3, day=10) self.query_params = [ @@ -60,6 +60,30 @@ def test_hash(self): query2 = QueryBase(query_id=1, params=[QueryParameter.number_type("num", 1)]) self.assertNotEqual(hash(query1), hash(query2)) + def test_parse_object_or_id(self): + expected_params = { + "params.Date": "2021-01-01 12:34:56", + "params.Enum": "option1", + "params.Number": "12", + "params.Text": "plain text", + } + expected_query_id = self.query.query_id + # Query Object + self.assertEqual( + parse_query_object_or_id(self.query), (expected_params, expected_query_id) + ) + # Query ID (integer) + expected_params = None + self.assertEqual( + parse_query_object_or_id(self.query.query_id), + (expected_params, expected_query_id), + ) + # Query ID (string) + self.assertEqual( + parse_query_object_or_id(str(self.query.query_id)), + (expected_params, expected_query_id), + ) + if __name__ == "__main__": unittest.main()