Skip to content

Commit

Permalink
chore: refactor init and extract Query API (#93)
Browse files Browse the repository at this point in the history
* refactor: extract query API

* test: add pytest.ini with custom marker definition

* fix: polars module check in query API
  • Loading branch information
alespour authored Jun 19, 2024
1 parent a030485 commit ede1956
Show file tree
Hide file tree
Showing 6 changed files with 126 additions and 73 deletions.
64 changes: 8 additions & 56 deletions influxdb_client_3/__init__.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,17 @@
import json
import urllib.parse

import pyarrow as pa
from pyarrow.flight import FlightClient, Ticket, FlightCallOptions
import importlib.util

from influxdb_client_3.query.query_api import QueryApi as _QueryApi
from influxdb_client_3.read_file import UploadFile
from influxdb_client_3.write_client import InfluxDBClient as _InfluxDBClient, WriteOptions, Point
from influxdb_client_3.write_client.client.exceptions import InfluxDBError
from influxdb_client_3.write_client.client.write_api import WriteApi as _WriteApi, SYNCHRONOUS, ASYNCHRONOUS, \
PointSettings
from influxdb_client_3.write_client.domain.write_precision import WritePrecision
from influxdb_client_3.version import USER_AGENT

try:
import polars as pl

polars = True
except ImportError:
polars = False
polars = importlib.util.find_spec("polars") is not None


def write_client_options(**kwargs):
Expand Down Expand Up @@ -144,23 +138,15 @@ def __init__(
**kwargs)

self._write_api = _WriteApi(influxdb_client=self._client, **self._write_client_options)
self._flight_client_options = flight_client_options or {}

if query_port_overwrite is not None:
port = query_port_overwrite

gen_opts = [
("grpc.secondary_user_agent", USER_AGENT)
]

self._flight_client_options["generic_options"] = gen_opts

if scheme == 'https':
connection_string = f"grpc+tls://{hostname}:{port}"
else:
connection_string = f"grpc+tcp://{hostname}:{port}"

self._flight_client = FlightClient(connection_string, **self._flight_client_options)
self._query_api = _QueryApi(connection_string=connection_string, token=token,
flight_client_options=flight_client_options)

def write(self, record=None, database=None, **kwargs):
"""
Expand Down Expand Up @@ -258,48 +244,14 @@ def query(self, query: str, language: str = "sql", mode: str = "all", database:
database = self._database

try:
# Create an authorization header
optargs = {
"headers": [(b"authorization", f"Bearer {self._token}".encode('utf-8'))],
"timeout": 300
}
opts = _merge_options(optargs, exclude_keys=['query_parameters'], custom=kwargs)
_options = FlightCallOptions(**opts)

#
# Ticket data
#
ticket_data = {
"database": database,
"sql_query": query,
"query_type": language
}
# add query parameters
query_parameters = kwargs.get("query_parameters", None)
if query_parameters:
ticket_data["params"] = query_parameters

ticket = Ticket(json.dumps(ticket_data).encode('utf-8'))
flight_reader = self._flight_client.do_get(ticket, _options)

mode_func = {
"all": flight_reader.read_all,
"pandas": flight_reader.read_pandas,
"polars": lambda: pl.from_arrow(flight_reader.read_all()),
"chunk": lambda: flight_reader,
"reader": flight_reader.to_reader,
"schema": lambda: flight_reader.schema

}.get(mode, flight_reader.read_all)

return mode_func() if callable(mode_func) else mode_func
except Exception as e:
return self._query_api.query(query=query, language=language, mode=mode, database=database, **kwargs)
except InfluxDBError as e:
raise e

def close(self):
"""Close the client and clean up resources."""
self._write_api.close()
self._flight_client.close()
self._query_api.close()
self._client.close()

def __enter__(self):
Expand Down
104 changes: 104 additions & 0 deletions influxdb_client_3/query/query_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
"""Query data in InfluxDB 3."""

# coding: utf-8
import json

from pyarrow.flight import FlightClient, Ticket, FlightCallOptions, FlightStreamReader
from influxdb_client_3.version import USER_AGENT


class QueryApi(object):
"""
Implementation for '/api/v2/query' endpoint.
Example:
.. code-block:: python
from influxdb_client import InfluxDBClient
# Initialize instance of QueryApi
with InfluxDBClient(url="http://localhost:8086", token="my-token", org="my-org") as client:
query_api = client.query_api()
"""

def __init__(self,
connection_string,
token,
flight_client_options) -> None:
"""
Initialize defaults.
:param connection_string: Flight/gRPC connection string
:param token: access token
:param flight_client_options: Flight client options
"""
self._token = token
self._flight_client_options = flight_client_options or {}
self._flight_client_options["generic_options"] = [
("grpc.secondary_user_agent", USER_AGENT)
]
self._flight_client = FlightClient(connection_string, **self._flight_client_options)

def query(self, query: str, language: str, mode: str, database: str, **kwargs):
"""Query data from InfluxDB.
:param query: The query to execute on the database.
:param language: The query language.
:param mode: The mode to use for the query.
It should be one of "all", "pandas", "polars", "chunk", "reader" or "schema".
:param database: The database to query from.
:param kwargs: Additional arguments to pass to the ``FlightCallOptions headers``.
For example, it can be used to set up per request headers.
:keyword query_parameters: The query parameters to use in the query.
It should be a ``dictionary`` of key-value pairs.
:return: The query result in the specified mode.
"""
from influxdb_client_3 import polars as has_polars, _merge_options as merge_options
try:
# Create an authorization header
optargs = {
"headers": [(b"authorization", f"Bearer {self._token}".encode('utf-8'))],
"timeout": 300
}
opts = merge_options(optargs, exclude_keys=['query_parameters'], custom=kwargs)
_options = FlightCallOptions(**opts)

#
# Ticket data
#
ticket_data = {
"database": database,
"sql_query": query,
"query_type": language
}
# add query parameters
query_parameters = kwargs.get("query_parameters", None)
if query_parameters:
ticket_data["params"] = query_parameters

ticket = Ticket(json.dumps(ticket_data).encode('utf-8'))
flight_reader = self._do_get(ticket, _options)

mode_funcs = {
"all": flight_reader.read_all,
"pandas": flight_reader.read_pandas,
"chunk": lambda: flight_reader,
"reader": flight_reader.to_reader,
"schema": lambda: flight_reader.schema
}
if has_polars:
import polars as pl
mode_funcs["polars"] = lambda: pl.from_arrow(flight_reader.read_all())
mode_func = mode_funcs.get(mode, flight_reader.read_all)

return mode_func() if callable(mode_func) else mode_func
except Exception as e:
raise e

def _do_get(self, ticket: Ticket, options: FlightCallOptions = None) -> FlightStreamReader:
return self._flight_client.do_get(ticket, options)

def close(self):
"""Close the Flight client."""
self._flight_client.close()
3 changes: 3 additions & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[pytest]
markers =
integration: marks integration tests (deselect with '-m "not integration"')
8 changes: 4 additions & 4 deletions tests/test_influxdb_client_3.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@ class TestInfluxDBClient3(unittest.TestCase):

@patch('influxdb_client_3._InfluxDBClient')
@patch('influxdb_client_3._WriteApi')
@patch('influxdb_client_3.FlightClient')
def setUp(self, mock_flight_client, mock_write_api, mock_influx_db_client):
@patch('influxdb_client_3._QueryApi')
def setUp(self, mock_query_api, mock_write_api, mock_influx_db_client):
self.mock_influx_db_client = mock_influx_db_client
self.mock_write_api = mock_write_api
self.mock_flight_client = mock_flight_client
self.mock_query_api = mock_query_api
self.client = InfluxDBClient3(
host="localhost",
org="my_org",
Expand All @@ -25,7 +25,7 @@ def test_init(self):
self.assertEqual(self.client._database, "my_db")
self.assertEqual(self.client._client, self.mock_influx_db_client.return_value)
self.assertEqual(self.client._write_api, self.mock_write_api.return_value)
self.assertEqual(self.client._flight_client, self.mock_flight_client.return_value)
self.assertEqual(self.client._query_api, self.mock_query_api.return_value)


if __name__ == '__main__':
Expand Down
1 change: 1 addition & 0 deletions tests/test_influxdb_client_3_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def test_write_and_query(self):

df = self.client.query(sql, mode="pandas", query_parameters={'type': 'used', 'test_id': test_id})

self.assertIsNotNone(df)
self.assertEqual(1, len(df))
self.assertEqual(test_id, df['test_id'][0])
self.assertEqual(123.0, df['value'][0])
19 changes: 6 additions & 13 deletions tests/test_query.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import unittest
import struct
from unittest.mock import Mock, patch, ANY
from unittest.mock import Mock, ANY

from pyarrow import (
array,
Expand Down Expand Up @@ -66,7 +66,8 @@ class HeaderCheckServerMiddleware(ServerMiddleware):
Middleware needed to catch request headers via factory
N.B. As found in pyarrow tests
"""
def __init__(self, token):
def __init__(self, token, *args, **kwargs):
super().__init__(*args, **kwargs)
self.token = token

def sending_headers(self):
Expand Down Expand Up @@ -114,25 +115,17 @@ def test_influx_default_query_headers():

class TestQuery(unittest.TestCase):

@patch('influxdb_client_3._InfluxDBClient')
@patch('influxdb_client_3._WriteApi')
@patch('influxdb_client_3.FlightClient')
def setUp(self, mock_flight_client, mock_write_api, mock_influx_db_client):
self.mock_influx_db_client = mock_influx_db_client
self.mock_write_api = mock_write_api
self.mock_flight_client = mock_flight_client
def setUp(self):
self.client = InfluxDBClient3(
host="localhost",
org="my_org",
database="my_db",
token="my_token"
)
self.client._flight_client = mock_flight_client
self.client._write_api = mock_write_api

def test_query_without_parameters(self):
mock_do_get = Mock()
self.client._flight_client.do_get = mock_do_get
self.client._query_api._do_get = mock_do_get

self.client.query('SELECT * FROM measurement')

Expand All @@ -146,7 +139,7 @@ def test_query_without_parameters(self):

def test_query_with_parameters(self):
mock_do_get = Mock()
self.client._flight_client.do_get = mock_do_get
self.client._query_api._do_get = mock_do_get

self.client.query('SELECT * FROM measurement WHERE time > $time', query_parameters={"time": "2021-01-01"})

Expand Down

0 comments on commit ede1956

Please sign in to comment.