Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: refactor init and extract Query API #93

Merged
merged 5 commits into from
Jun 19, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 8 additions & 56 deletions influxdb_client_3/__init__.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,17 @@
import json
import urllib.parse

import pyarrow as pa
from pyarrow.flight import FlightClient, Ticket, FlightCallOptions
import importlib.util

from influxdb_client_3.query.query_api import QueryApi as _QueryApi
from influxdb_client_3.read_file import UploadFile
from influxdb_client_3.write_client import InfluxDBClient as _InfluxDBClient, WriteOptions, Point
from influxdb_client_3.write_client.client.exceptions import InfluxDBError
from influxdb_client_3.write_client.client.write_api import WriteApi as _WriteApi, SYNCHRONOUS, ASYNCHRONOUS, \
PointSettings
from influxdb_client_3.write_client.domain.write_precision import WritePrecision
from influxdb_client_3.version import USER_AGENT

try:
import polars as pl

polars = True
except ImportError:
polars = False
polars = importlib.util.find_spec("polars") is not None


def write_client_options(**kwargs):
Expand Down Expand Up @@ -144,23 +138,15 @@
**kwargs)

self._write_api = _WriteApi(influxdb_client=self._client, **self._write_client_options)
self._flight_client_options = flight_client_options or {}

if query_port_overwrite is not None:
port = query_port_overwrite

gen_opts = [
("grpc.secondary_user_agent", USER_AGENT)
]

self._flight_client_options["generic_options"] = gen_opts

if scheme == 'https':
connection_string = f"grpc+tls://{hostname}:{port}"
else:
connection_string = f"grpc+tcp://{hostname}:{port}"

self._flight_client = FlightClient(connection_string, **self._flight_client_options)
self._query_api = _QueryApi(connection_string=connection_string, token=token,
flight_client_options=flight_client_options)

def write(self, record=None, database=None, **kwargs):
"""
Expand Down Expand Up @@ -258,48 +244,14 @@
database = self._database

try:
# Create an authorization header
optargs = {
"headers": [(b"authorization", f"Bearer {self._token}".encode('utf-8'))],
"timeout": 300
}
opts = _merge_options(optargs, exclude_keys=['query_parameters'], custom=kwargs)
_options = FlightCallOptions(**opts)

#
# Ticket data
#
ticket_data = {
"database": database,
"sql_query": query,
"query_type": language
}
# add query parameters
query_parameters = kwargs.get("query_parameters", None)
if query_parameters:
ticket_data["params"] = query_parameters

ticket = Ticket(json.dumps(ticket_data).encode('utf-8'))
flight_reader = self._flight_client.do_get(ticket, _options)

mode_func = {
"all": flight_reader.read_all,
"pandas": flight_reader.read_pandas,
"polars": lambda: pl.from_arrow(flight_reader.read_all()),
"chunk": lambda: flight_reader,
"reader": flight_reader.to_reader,
"schema": lambda: flight_reader.schema

}.get(mode, flight_reader.read_all)

return mode_func() if callable(mode_func) else mode_func
except Exception as e:
return self._query_api.query(query=query, language=language, mode=mode, database=database, **kwargs)
except InfluxDBError as e:

Check warning on line 248 in influxdb_client_3/__init__.py

View check run for this annotation

Codecov / codecov/patch

influxdb_client_3/__init__.py#L248

Added line #L248 was not covered by tests
raise e

def close(self):
"""Close the client and clean up resources."""
self._write_api.close()
self._flight_client.close()
self._query_api.close()
self._client.close()

def __enter__(self):
Expand Down
105 changes: 105 additions & 0 deletions influxdb_client_3/query/query_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
"""Query data in InfluxDB 3."""

# coding: utf-8
import json

from pyarrow.flight import FlightClient, Ticket, FlightCallOptions, FlightStreamReader
from influxdb_client_3.version import USER_AGENT


class QueryApi(object):
"""
Implementation for '/api/v2/query' endpoint.

Example:
.. code-block:: python

from influxdb_client import InfluxDBClient


# Initialize instance of QueryApi
with InfluxDBClient(url="http://localhost:8086", token="my-token", org="my-org") as client:
query_api = client.query_api()
"""

def __init__(self,
connection_string,
token,
flight_client_options,
**kwargs) -> None:
alespour marked this conversation as resolved.
Show resolved Hide resolved
"""
Initialize defaults.

:param connection_string: Flight/gRPC connection string
:param token: access token
:param flight_client_options: Flight client options
"""
self._token = token
self._flight_client_options = flight_client_options or {}
self._flight_client_options["generic_options"] = [
("grpc.secondary_user_agent", USER_AGENT)
]
self._flight_client = FlightClient(connection_string, **self._flight_client_options)

def query(self, query: str, language: str, mode: str, database: str, **kwargs):
"""Query data from InfluxDB.

:param query: The query to execute on the database.
:param language: The query language.
:param mode: The mode to use for the query.
It should be one of "all", "pandas", "polars", "chunk", "reader" or "schema".
:param database: The database to query from.
:param kwargs: Additional arguments to pass to the ``FlightCallOptions headers``.
For example, it can be used to set up per request headers.
:keyword query_parameters: The query parameters to use in the query.
It should be a ``dictionary`` of key-value pairs.
:return: The query result in the specified mode.
"""
from influxdb_client_3 import polars as has_polars, _merge_options as merge_options
try:
# Create an authorization header
optargs = {
"headers": [(b"authorization", f"Bearer {self._token}".encode('utf-8'))],
"timeout": 300
}
opts = merge_options(optargs, exclude_keys=['query_parameters'], custom=kwargs)
_options = FlightCallOptions(**opts)

#
# Ticket data
#
ticket_data = {
"database": database,
"sql_query": query,
"query_type": language
}
# add query parameters
query_parameters = kwargs.get("query_parameters", None)
if query_parameters:
ticket_data["params"] = query_parameters

ticket = Ticket(json.dumps(ticket_data).encode('utf-8'))
flight_reader = self._do_get(ticket, _options)

mode_funcs = {
"all": flight_reader.read_all,
"pandas": flight_reader.read_pandas,
"chunk": lambda: flight_reader,
"reader": flight_reader.to_reader,
"schema": lambda: flight_reader.schema
}
if has_polars:
import polars as pl
mode_funcs["polars"] = lambda: pl.from_arrow(flight_reader.read_all())
mode_func = mode_funcs.get(mode, flight_reader.read_all)

return mode_func() if callable(mode_func) else mode_func
except Exception as e:
raise e

Check warning on line 98 in influxdb_client_3/query/query_api.py

View check run for this annotation

Codecov / codecov/patch

influxdb_client_3/query/query_api.py#L97-L98

Added lines #L97 - L98 were not covered by tests

def _do_get(self, ticket: Ticket, options: FlightCallOptions = None) -> FlightStreamReader:
return self._flight_client.do_get(ticket, options)

def close(self):
"""Close the Flight client."""
self._flight_client.close()
3 changes: 3 additions & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[pytest]
markers =
integration: marks integration tests (deselect with '-m "not integration"')
8 changes: 4 additions & 4 deletions tests/test_influxdb_client_3.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@ class TestInfluxDBClient3(unittest.TestCase):

@patch('influxdb_client_3._InfluxDBClient')
@patch('influxdb_client_3._WriteApi')
@patch('influxdb_client_3.FlightClient')
def setUp(self, mock_flight_client, mock_write_api, mock_influx_db_client):
@patch('influxdb_client_3._QueryApi')
def setUp(self, mock_query_api, mock_write_api, mock_influx_db_client):
self.mock_influx_db_client = mock_influx_db_client
self.mock_write_api = mock_write_api
self.mock_flight_client = mock_flight_client
self.mock_query_api = mock_query_api
self.client = InfluxDBClient3(
host="localhost",
org="my_org",
Expand All @@ -25,7 +25,7 @@ def test_init(self):
self.assertEqual(self.client._database, "my_db")
self.assertEqual(self.client._client, self.mock_influx_db_client.return_value)
self.assertEqual(self.client._write_api, self.mock_write_api.return_value)
self.assertEqual(self.client._flight_client, self.mock_flight_client.return_value)
self.assertEqual(self.client._query_api, self.mock_query_api.return_value)


if __name__ == '__main__':
Expand Down
1 change: 1 addition & 0 deletions tests/test_influxdb_client_3_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def test_write_and_query(self):

df = self.client.query(sql, mode="pandas", query_parameters={'type': 'used', 'test_id': test_id})

self.assertIsNotNone(df)
self.assertEqual(1, len(df))
self.assertEqual(test_id, df['test_id'][0])
self.assertEqual(123.0, df['value'][0])
19 changes: 6 additions & 13 deletions tests/test_query.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import unittest
import struct
from unittest.mock import Mock, patch, ANY
from unittest.mock import Mock, ANY

from pyarrow import (
array,
Expand Down Expand Up @@ -66,7 +66,8 @@ class HeaderCheckServerMiddleware(ServerMiddleware):
Middleware needed to catch request headers via factory
N.B. As found in pyarrow tests
"""
def __init__(self, token):
def __init__(self, token, *args, **kwargs):
super().__init__(*args, **kwargs)
self.token = token

def sending_headers(self):
Expand Down Expand Up @@ -114,25 +115,17 @@ def test_influx_default_query_headers():

class TestQuery(unittest.TestCase):

@patch('influxdb_client_3._InfluxDBClient')
@patch('influxdb_client_3._WriteApi')
@patch('influxdb_client_3.FlightClient')
def setUp(self, mock_flight_client, mock_write_api, mock_influx_db_client):
self.mock_influx_db_client = mock_influx_db_client
self.mock_write_api = mock_write_api
self.mock_flight_client = mock_flight_client
def setUp(self):
self.client = InfluxDBClient3(
host="localhost",
org="my_org",
database="my_db",
token="my_token"
)
self.client._flight_client = mock_flight_client
self.client._write_api = mock_write_api

def test_query_without_parameters(self):
mock_do_get = Mock()
self.client._flight_client.do_get = mock_do_get
self.client._query_api._do_get = mock_do_get

self.client.query('SELECT * FROM measurement')

Expand All @@ -146,7 +139,7 @@ def test_query_without_parameters(self):

def test_query_with_parameters(self):
mock_do_get = Mock()
self.client._flight_client.do_get = mock_do_get
self.client._query_api._do_get = mock_do_get

self.client.query('SELECT * FROM measurement WHERE time > $time', query_parameters={"time": "2021-01-01"})

Expand Down
Loading