Skip to content

Commit

Permalink
add download_csv functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
bh2smith committed Sep 8, 2023
1 parent 6c1ef87 commit 90187a1
Show file tree
Hide file tree
Showing 7 changed files with 124 additions and 39 deletions.
2 changes: 0 additions & 2 deletions dune_client/api/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,6 @@ def get_execution_results_csv(self, job_id: str) -> ExecutionResultCSV:
if you need metadata information use get_results() or get_status()
"""
route = f"/execution/{job_id}/results/csv"
url = self._route_url(f"/execution/{job_id}/results/csv")
self.logger.debug(f"GET CSV received input url={url}")
response = self._get(route=route, raw=True)
response.raise_for_status()
return ExecutionResultCSV(data=BytesIO(response.content))
Expand Down
32 changes: 20 additions & 12 deletions dune_client/api/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

import time
from io import BytesIO
from typing import Union, Optional, Any

from deprecated import deprecated
Expand All @@ -15,7 +16,7 @@
QueryFailed,
ExecutionResultCSV,
)
from dune_client.query import QueryBase
from dune_client.query import QueryBase, parse_query_object_or_id


class ExtendedAPI(ExecutionAPI):
Expand Down Expand Up @@ -78,19 +79,11 @@ def get_latest_result(self, query: Union[QueryBase, str, int]) -> ResultsRespons
GET the latest results for a query_id without re-executing the query
(doesn't use execution credits)
:param query: :class:`Query` object OR query id as string | int
:param query: :class:`Query` object OR query id as string or int
https://dune.com/docs/api/api-reference/latest_results/
https://dune.com/docs/api/api-reference/get-results/latest-results
"""
if isinstance(query, QueryBase):
params = {
f"params.{p.key}": p.to_dict()["value"] for p in query.parameters()
}
query_id = query.query_id
else:
params = None
query_id = int(query)

params, query_id = parse_query_object_or_id(query)
response_json = self._get(
route=f"/query/{query_id}/results",
params=params,
Expand All @@ -100,6 +93,21 @@ def get_latest_result(self, query: Union[QueryBase, str, int]) -> ResultsRespons
except KeyError as err:
raise DuneError(response_json, "ResultsResponse", err) from err

def download_csv(self, query: Union[QueryBase, str, int]) -> ExecutionResultCSV:
"""
Almost like an alias for `get_latest_result` but for the csv endpoint.
https://dune.com/docs/api/api-reference/get-results/latest-results
"""
params, query_id = parse_query_object_or_id(query)
response = self._get(
route=f"/query/{query_id}/results/csv", params=params, raw=True
)
response.raise_for_status()
return ExecutionResultCSV(data=BytesIO(response.content))

############################
# Plus Subscription Features
############################
def upload_csv(self, table_name: str, data: str, description: str = "") -> bool:
"""
https://dune.com/docs/api/api-reference/upload-data/?h=data+upload#endpoint
Expand Down
12 changes: 2 additions & 10 deletions dune_client/client_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
ExecutionState,
)

from dune_client.query import QueryBase
from dune_client.query import QueryBase, parse_query_object_or_id


# pylint: disable=duplicate-code
Expand Down Expand Up @@ -181,15 +181,7 @@ async def get_latest_result(
https://dune.com/docs/api/api-reference/latest_results/
"""
if isinstance(query, QueryBase):
params = {
f"params.{p.key}": p.to_dict()["value"] for p in query.parameters()
}
query_id = query.query_id
else:
params = None
query_id = int(query)

params, query_id = parse_query_object_or_id(query)
response_json = await self._get(
route=f"/query/{query_id}/results",
params=params,
Expand Down
2 changes: 1 addition & 1 deletion dune_client/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ class ExecutionResultCSV:
Representation of a raw `result` in CSV format
this payload can be passed directly to
csv.reader(data) or
pandas.from_csv(data)
pandas.read_csv(data)
"""

data: BytesIO # includes all CSV rows, including the header row.
Expand Down
16 changes: 16 additions & 0 deletions dune_client/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,22 @@
from dune_client.types import QueryParameter


def parse_query_object_or_id(
query: Union[QueryBase, str, int],
) -> tuple[dict[str, str] | None, int]:
"""
Users are allowed to pass QueryBase or ID into some functions.
This method handles both scenarios, returning a pair of the form (params, query_id)
"""
if isinstance(query, QueryBase):
params = {f"params.{p.key}": p.to_dict()["value"] for p in query.parameters()}
query_id = query.query_id
else:
params = None
query_id = int(query)
return params, query_id


@dataclass
class QueryBase:
"""Basic data structure constituting a Dune Analytics Query."""
Expand Down
71 changes: 59 additions & 12 deletions tests/e2e/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import unittest

import dotenv
import pandas

from dune_client.models import (
ExecutionState,
Expand Down Expand Up @@ -33,6 +34,18 @@ def setUp(self) -> None:
)
self.valid_api_key = os.environ["DUNE_API_KEY"]

def copy_query_and_change_parameters(self) -> QueryBase:
new_query = copy.copy(self.query)
new_query.params = [
# Using all different values for parameters.
QueryParameter.text_type(name="TextField", value="different word"),
QueryParameter.number_type(name="NumberField", value=22),
QueryParameter.date_type(name="DateField", value="1991-01-01 00:00:00"),
QueryParameter.enum_type(name="ListField", value="Option 2"),
]
self.assertNotEqual(self.query.parameters(), new_query.parameters())
return new_query

def test_from_env_constructor(self):
try:
DuneClient.from_env()
Expand Down Expand Up @@ -64,19 +77,9 @@ def test_run_query_dataframe(self):
self.assertGreater(len(pd), 0)

def test_parameters_recognized(self):
query = copy.copy(self.query)
new_params = [
# Using all different values for parameters.
QueryParameter.text_type(name="TextField", value="different word"),
QueryParameter.number_type(name="NumberField", value=22),
QueryParameter.date_type(name="DateField", value="1991-01-01 00:00:00"),
QueryParameter.enum_type(name="ListField", value="Option 2"),
]
query.params = new_params
self.assertEqual(query.parameters(), new_params)

new_query = self.copy_query_and_change_parameters()
dune = DuneClient(self.valid_api_key)
results = dune.run_query(query)
results = dune.run_query(new_query)
self.assertEqual(
results.get_rows(),
[
Expand Down Expand Up @@ -194,6 +197,50 @@ def test_upload_csv_success(self):
True,
)

def test_download_csv_success_by_id(self):
client = DuneClient(self.valid_api_key)
new_query = self.copy_query_and_change_parameters()
# Run query with new parameters
client.run_query(new_query)
# Download CSV by query_id
result_csv = client.download_csv(self.query.query_id)
# Expect that the csv returns the latest execution results (i.e. those that were just run)
self.assertEqual(
pandas.read_csv(result_csv.data).to_dict(orient="records"),
[
{
"text_field": "different word",
"number_field": 22,
"date_field": "1991-01-01 00:00:00.000",
"list_field": "Option 2",
}
],
)

def test_download_csv_success_with_params(self):
client = DuneClient(self.valid_api_key)
# Download CSV with query and given parameters.
result_csv = client.download_csv(self.query)
# Expect the result to be relative to values of given parameters.
#################################################################
# Note that we could compare results with
# ",".join([p.value for p in self.query.parameters()]) + "\n"
# but there seems to be a discrepancy with the date string values.
# Specifically 1991-01-01 00:00:00.000
# vs 1991-01-01 00:00:00
#################################################################
self.assertEqual(
pandas.read_csv(result_csv.data).to_dict(orient="records"),
[
{
"date_field": "2022-05-04 00:00:00.000",
"list_field": "Option 1",
"number_field": 3.1415926535,
"text_field": "Plain Text",
}
],
)


@unittest.skip("This is an enterprise only endpoint that can no longer be tested.")
class TestCRUDOps(unittest.TestCase):
Expand Down
28 changes: 26 additions & 2 deletions tests/unit/test_query.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import unittest
from datetime import datetime

from dune_client.query import QueryBase
from dune_client.query import QueryBase, parse_query_object_or_id
from dune_client.types import QueryParameter


class TestQueryMonitor(unittest.TestCase):
class TestQueryBase(unittest.TestCase):
def setUp(self) -> None:
self.date = datetime(year=1985, month=3, day=10)
self.query_params = [
Expand Down Expand Up @@ -60,6 +60,30 @@ def test_hash(self):
query2 = QueryBase(query_id=1, params=[QueryParameter.number_type("num", 1)])
self.assertNotEqual(hash(query1), hash(query2))

def test_parse_object_or_id(self):
expected_params = {
"params.Date": "2021-01-01 12:34:56",
"params.Enum": "option1",
"params.Number": "12",
"params.Text": "plain text",
}
expected_query_id = self.query.query_id
# Query Object
self.assertEqual(
parse_query_object_or_id(self.query), (expected_params, expected_query_id)
)
# Query ID (integer)
expected_params = None
self.assertEqual(
parse_query_object_or_id(self.query.query_id),
(expected_params, expected_query_id),
)
# Query ID (string)
self.assertEqual(
parse_query_object_or_id(str(self.query.query_id)),
(expected_params, expected_query_id),
)


if __name__ == "__main__":
unittest.main()

0 comments on commit 90187a1

Please sign in to comment.