Merge pull request #72 from InfluxCommunity/67-support-polars-conversion
67 support polars conversion
Jayclifford345 authored Dec 20, 2023
2 parents 436930d + 05b04fe commit f18a191
Showing 5 changed files with 171 additions and 4 deletions.
10 changes: 10 additions & 0 deletions influxdb_client_3/__init__.py
@@ -7,6 +7,11 @@
from pyarrow.flight import FlightClient, Ticket, FlightCallOptions
from influxdb_client_3.read_file import UploadFile
import urllib.parse
try:
import polars as pl
polars = True
except ImportError:
polars = False



@@ -216,7 +221,10 @@ def query(self, query, language="sql", mode="all", database=None,**kwargs ):
:param kwargs: FlightClientCallOptions for the query.
:return: The queried data.
"""
if mode == "polars" and polars is False:
raise ImportError("Polars is not installed. Please install it with `pip install polars`.")



if database is None:
database = self._database
@@ -237,9 +245,11 @@ def query(self, query, language="sql", mode="all", database=None,**kwargs ):
mode_func = {
"all": flight_reader.read_all,
"pandas": flight_reader.read_pandas,
"polars": lambda: pl.from_arrow(flight_reader.read_all()),
"chunk": lambda: flight_reader,
"reader": flight_reader.to_reader,
"schema": lambda: flight_reader.schema

}.get(mode, flight_reader.read_all)

return mode_func() if callable(mode_func) else mode_func
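A minimal usage sketch of the new query mode (host, token, database, and the query string are placeholders; the only behavior relied on is the mode="polars" path added above):

from influxdb_client_3 import InfluxDBClient3

client = InfluxDBClient3(
    host="<cloud-host>",        # placeholder
    token="<api-token>",        # placeholder
    database="my_database",     # placeholder
)

# mode="polars" returns pl.from_arrow(flight_reader.read_all()), i.e. a polars.DataFrame;
# it raises ImportError if polars is not installed.
df = client.query("SELECT * FROM my_measurement", language="sql", mode="polars")
print(type(df))  # polars.DataFrame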
2 changes: 1 addition & 1 deletion influxdb_client_3/write_client/client/_base.py
@@ -219,7 +219,7 @@ def __init__(self, influxdb_client, point_settings=None):
self._point_settings.add_default_tag(key, value)

def _append_default_tag(self, key, val, record):
from write_client import Point
from influxdb_client_3.write_client import Point
if isinstance(record, bytes) or isinstance(record, str):
pass
elif isinstance(record, Point):
142 changes: 142 additions & 0 deletions influxdb_client_3/write_client/client/write/dataframe_serializer.py
@@ -280,6 +280,132 @@ def number_of_chunks(self):
return self.number_of_chunks


class PolarsDataframeSerializer:
"""Serialize DataFrame into LineProtocols."""

def __init__(self, data_frame, point_settings, precision=DEFAULT_WRITE_PRECISION, chunk_size: int = None,
**kwargs) -> None:
"""
Init serializer.
:param data_frame: Polars DataFrame to serialize
:param point_settings: Default Tags
:param precision: The precision for the unix timestamps within the body line-protocol.
:param chunk_size: The size of chunk for serializing into chunks.
:key data_frame_measurement_name: name of measurement for writing Polars DataFrame
:key data_frame_tag_columns: list of DataFrame columns which are tags, rest columns will be fields
:key data_frame_timestamp_column: name of DataFrame column which contains a timestamp.
:key data_frame_timestamp_timezone: name of the timezone which is used for timestamp column
"""


self.data_frame = data_frame
self.point_settings = point_settings
self.precision = precision
self.chunk_size = chunk_size
self.measurement_name = kwargs.get("data_frame_measurement_name", "measurement")
self.tag_columns = kwargs.get("data_frame_tag_columns", [])
self.timestamp_column = kwargs.get("data_frame_timestamp_column", None)
self.timestamp_timezone = kwargs.get("data_frame_timestamp_timezone", None)

self.column_indices = {name: index for index, name in enumerate(data_frame.columns)}

#
# prepare chunks
#
if chunk_size is not None:
self.number_of_chunks = int(math.ceil(len(data_frame) / float(chunk_size)))
self.chunk_size = chunk_size
else:
self.number_of_chunks = None

def escape_value(self, value):
return str(value).translate(_ESCAPE_KEY)


def to_line_protocol(self, row):
# Filter out None or empty values for tags
tags = ""

tags = ",".join(
f'{self.escape_value(col)}={self.escape_value(row[self.column_indices[col]])}'
for col in self.tag_columns
if row[self.column_indices[col]] is not None and row[self.column_indices[col]] != ""
)

if self.point_settings.defaultTags:
default_tags = ",".join(
f'{self.escape_value(key)}={self.escape_value(value)}'
for key, value in self.point_settings.defaultTags.items()
)
# Ensure there's a comma between existing tags and default tags if both are present
if tags and default_tags:
tags += ","
tags += default_tags




# Build the field set: string values are quoted, integers get an "i" suffix,
# and other values (e.g. floats) are written as-is
fields = ",".join(
f"{col}=\"{row[self.column_indices[col]]}\"" if isinstance(row[self.column_indices[col]], str)
else f"{col}={row[self.column_indices[col]]}i" if isinstance(row[self.column_indices[col]], int)
else f"{col}={row[self.column_indices[col]]}"
for col in self.column_indices
if col not in self.tag_columns + [self.timestamp_column]
and row[self.column_indices[col]] is not None and row[self.column_indices[col]] != ""
)

# Access the Unix timestamp
timestamp = row[self.column_indices[self.timestamp_column]]
if tags != "":
line_protocol = f"{self.measurement_name},{tags} {fields} {timestamp}"
else:
line_protocol = f"{self.measurement_name} {fields} {timestamp}"

return line_protocol
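To illustrate the shape of the output (hypothetical column names and values, assuming the default measurement name and a timestamp column already converted to nanoseconds): a row with tag column location="office" and field columns temp=21.5 (float) and count=3 (int) would serialize to a line like

measurement,location=office temp=21.5,count=3i 1703073600000000000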


def serialize(self, chunk_idx: int = None):
from ...extras import pl

df = self.data_frame

# Convert timestamp to unix timestamp
print(self.precision)
if self.precision is None:
df = df.with_columns(pl.col(self.timestamp_column).dt.epoch(time_unit="ns").alias(self.timestamp_column))
elif self.precision == 'ns':
df = df.with_columns(pl.col(self.timestamp_column).dt.epoch(time_unit="ns").alias(self.timestamp_column))
elif self.precision == 'us':
df = df.with_columns(pl.col(self.timestamp_column).dt.epoch(time_unit="us").alias(self.timestamp_column))
elif self.precision == 'ms':
df = df.with_columns(pl.col(self.timestamp_column).dt.epoch(time_unit="ms").alias(self.timestamp_column))
elif self.precision == 's':
df = df.with_columns(pl.col(self.timestamp_column).dt.epoch(time_unit="s").alias(self.timestamp_column))
else:
raise ValueError(f"Unsupported precision: {self.precision}")

if chunk_idx is None:
chunk = df
else:
logger.debug("Serialize chunk %s/%s ...", chunk_idx + 1, self.number_of_chunks)
chunk = df[chunk_idx * self.chunk_size:(chunk_idx + 1) * self.chunk_size]

# Apply the row-wise UDF; DataFrame.apply returns a single-column frame named "map"
line_protocol_expr = chunk.apply(self.to_line_protocol, return_dtype=pl.Object)

lp = line_protocol_expr['map'].to_list()


return lp
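As a hedged illustration of the precision mapping above (the column name "time" is hypothetical), polars' dt.epoch converts a datetime column to an integer unix timestamp at the requested unit:

import polars as pl
from datetime import datetime

df = pl.DataFrame({"time": [datetime(2023, 12, 20, 12, 0, 0)]})
# the 'ms' branch above: milliseconds since the Unix epoch
df = df.with_columns(pl.col("time").dt.epoch(time_unit="ms").alias("time"))
print(df["time"][0])  # 1703073600000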






def data_frame_to_list_of_points(data_frame, point_settings, precision=DEFAULT_WRITE_PRECISION, **kwargs):
"""
Serialize DataFrame into LineProtocols.
@@ -295,3 +421,19 @@ def data_frame_to_list_of_points(data_frame, point_settings, precision=DEFAULT_W
:key data_frame_timestamp_timezone: name of the timezone which is used for timestamp column - ``DataFrame``
""" # noqa: E501
return DataframeSerializer(data_frame, point_settings, precision, **kwargs).serialize()

def polars_data_frame_to_list_of_points(data_frame, point_settings, precision=DEFAULT_WRITE_PRECISION, **kwargs):
"""
Serialize a Polars DataFrame into LineProtocols.
:param data_frame: Polars DataFrame to serialize
:param point_settings: Default Tags
:param precision: The precision for the unix timestamps within the body line-protocol.
:key data_frame_measurement_name: name of measurement for writing Polars DataFrame
:key data_frame_tag_columns: list of DataFrame columns which are tags, rest columns will be fields
:key data_frame_timestamp_column: name of DataFrame column which contains a timestamp. The column should be a
Polars datetime column (``pl.Datetime``) so it can be converted to a unix epoch at the requested precision.
:key data_frame_timestamp_timezone: name of the timezone which is used for timestamp column
""" # noqa: E501
return PolarsDataframeSerializer(data_frame, point_settings, precision, **kwargs).serialize()
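A hedged sketch of calling the new helper directly; the column names, measurement name, and the PointSettings import path are assumptions for illustration, not something shown in this diff:

from datetime import datetime

import polars as pl

from influxdb_client_3.write_client.client.write.dataframe_serializer import polars_data_frame_to_list_of_points
from influxdb_client_3.write_client.client.write_api import PointSettings  # assumed import path

df = pl.DataFrame({
    "time": [datetime(2023, 12, 20, 12, 0, 0)],
    "location": ["office"],
    "temp": [21.5],
})

lines = polars_data_frame_to_list_of_points(
    df,
    PointSettings(),
    precision="ns",
    data_frame_measurement_name="environment",
    data_frame_tag_columns=["location"],
    data_frame_timestamp_column="time",
)
print(lines)  # e.g. ['environment,location=office temp=21.5 1703073600000000000']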
14 changes: 12 additions & 2 deletions influxdb_client_3/write_client/client/write_api.py
@@ -19,7 +19,7 @@
from influxdb_client_3.write_client.domain import WritePrecision
from influxdb_client_3.write_client.client._base import _BaseWriteApi, _HAS_DATACLASS
from influxdb_client_3.write_client.client.util.helpers import get_org_query_param
from influxdb_client_3.write_client.client.write.dataframe_serializer import DataframeSerializer
from influxdb_client_3.write_client.client.write.dataframe_serializer import DataframeSerializer, PolarsDataframeSerializer
from influxdb_client_3.write_client.client.write.point import Point, DEFAULT_WRITE_PRECISION
from influxdb_client_3.write_client.client.write.retry import WritesRetry
from influxdb_client_3.write_client.rest import _UTF_8_encoding
@@ -460,14 +460,24 @@ def _write_batching(self, bucket, org, data,
elif isinstance(data, dict):
self._write_batching(bucket, org, Point.from_dict(data, write_precision=precision, **kwargs),
precision, **kwargs)

elif 'polars' in str(type(data)):
serializer = PolarsDataframeSerializer(data, self._point_settings, precision, self._write_options.batch_size,
**kwargs)
for chunk_idx in range(serializer.number_of_chunks):
self._write_batching(bucket, org,
serializer.serialize(chunk_idx),
precision, **kwargs)

elif 'DataFrame' in type(data).__name__:
elif 'pandas' in str(type(data)):
serializer = DataframeSerializer(data, self._point_settings, precision, self._write_options.batch_size,
**kwargs)
for chunk_idx in range(serializer.number_of_chunks):
self._write_batching(bucket, org,
serializer.serialize(chunk_idx),
precision, **kwargs)


elif hasattr(data, "_asdict"):
# noinspection PyProtectedMember
self._write_batching(bucket, org, data._asdict(), precision, **kwargs)
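End-to-end, writing a polars DataFrame now routes through the 'polars' branch above. A hedged sketch: host, token, and database are placeholders, and the write_client_options/WriteOptions wiring reflects the library's usual batching configuration rather than anything shown in this diff:

from datetime import datetime

import polars as pl

from influxdb_client_3 import InfluxDBClient3, WriteOptions, write_client_options

wco = write_client_options(write_options=WriteOptions(batch_size=500))
client = InfluxDBClient3(host="<cloud-host>", token="<api-token>", database="my_database",
                         write_client_options=wco)

df = pl.DataFrame({
    "time": [datetime(2023, 12, 20, 12, 0, 0), datetime(2023, 12, 20, 12, 1, 0)],
    "location": ["office", "lab"],
    "temp": [21.5, 19.8],
})

# 'polars' in str(type(df)) is True, so _write_batching selects PolarsDataframeSerializer
client.write(df,
             data_frame_measurement_name="environment",
             data_frame_tag_columns=["location"],
             data_frame_timestamp_column="time")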
7 changes: 6 additions & 1 deletion influxdb_client_3/write_client/extras.py
@@ -10,4 +10,9 @@
except ModuleNotFoundError as err:
raise ImportError(f"`data_frame` requires numpy which couldn't be imported due: {err}")

__all__ = ['pd', 'np']
try:
import polars as pl
except ModuleNotFoundError as err:
raise ImportError(f"`polars_frame` requires polars which couldn't be imported due: {err}")

__all__ = ['pd', 'np', 'pl']
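Downstream code can consume this guarded import without hard-failing when polars is absent, mirroring the try/except pattern this commit adds to influxdb_client_3/__init__.py. A minimal sketch (_HAS_POLARS is a hypothetical flag name):

try:
    from influxdb_client_3.write_client.extras import pl  # raises ImportError if polars is missing
    _HAS_POLARS = True
except ImportError:
    _HAS_POLARS = False  # polars-specific paths (mode="polars", Polars writes) stay disabled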
