Remove Collectors
V1NAY8 committed Aug 16, 2021
1 parent 76d83ea commit dad552a
Showing 5 changed files with 48 additions and 100 deletions.
2 changes: 1 addition & 1 deletion eland/dataframe.py
@@ -1310,7 +1310,7 @@ def to_csv(
         doublequote=True,
         escapechar=None,
         decimal=".",
-    ):
+    ) -> Optional[str]:
         """
         Write Elasticsearch data to a comma-separated values (csv) file.
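The new Optional[str] annotation mirrors pandas: to_csv() returns the CSV payload as a str when no path or buffer is passed, and None after writing to a file. A minimal sketch of both call styles (the connection URL and index name are assumptions):

import eland as ed

# Assumes a local Elasticsearch with a "flights" index (hypothetical names).
df = ed.DataFrame("http://localhost:9200", es_index_pattern="flights")

csv_text = df.to_csv()    # no path given -> the CSV is returned as a str
df.to_csv("flights.csv")  # path given -> the file is written, returns None
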
35 changes: 12 additions & 23 deletions eland/etl.py
@@ -526,29 +526,18 @@ def csv_to_eland(  # type: ignore

     first_write = True
     for chunk in reader:
-        if first_write:
-            pandas_to_eland(
-                chunk,
-                es_client,
-                es_dest_index,
-                es_if_exists=es_if_exists,
-                chunksize=chunksize,
-                es_refresh=es_refresh,
-                es_dropna=es_dropna,
-                es_type_overrides=es_type_overrides,
-            )
-            first_write = False
-        else:
-            pandas_to_eland(
-                chunk,
-                es_client,
-                es_dest_index,
-                es_if_exists="append",
-                chunksize=chunksize,
-                es_refresh=es_refresh,
-                es_dropna=es_dropna,
-                es_type_overrides=es_type_overrides,
-            )
+        pandas_to_eland(
+            chunk,
+            es_client,
+            es_dest_index,
+            chunksize=chunksize,
+            es_refresh=es_refresh,
+            es_dropna=es_dropna,
+            es_type_overrides=es_type_overrides,
+            # es_if_exists should be 'append' except on the first call to pandas_to_eland()
+            es_if_exists=(es_if_exists if first_write else "append"),
+        )
+        first_write = False

     # Now create an eland.DataFrame that references the new index
     return DataFrame(es_client, es_index_pattern=es_dest_index)
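The refactor collapses the two nearly identical calls into one: es_if_exists keeps the caller's value only for the first chunk, then switches to "append" so later chunks extend the index instead of recreating it. A self-contained sketch of the same first-write pattern, with write_batch() as a hypothetical stand-in for pandas_to_eland():

from typing import Iterable, List

def write_batches(batches: Iterable[List[int]], if_exists: str = "fail") -> None:
    # Hypothetical sink standing in for pandas_to_eland().
    def write_batch(batch: List[int], if_exists: str) -> None:
        print(f"writing {batch} with if_exists={if_exists!r}")

    first_write = True
    for batch in batches:
        # Honor the caller's policy only on the first write, then append.
        write_batch(batch, if_exists=(if_exists if first_write else "append"))
        first_write = False

write_batches([[1, 2], [3, 4]], if_exists="replace")
# writing [1, 2] with if_exists='replace'
# writing [3, 4] with if_exists='append'
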
87 changes: 12 additions & 75 deletions eland/operations.py
@@ -37,7 +37,6 @@

 from eland.actions import PostProcessingAction
 from eland.common import (
-    DEFAULT_CSV_BATCH_OUTPUT_SIZE,
     DEFAULT_PAGINATION_SIZE,
     DEFAULT_PIT_KEEP_ALIVE,
     DEFAULT_SEARCH_SIZE,
@@ -1198,32 +1197,23 @@ def describe(self, query_compiler: "QueryCompiler") -> pd.DataFrame:

     def to_pandas(
         self, query_compiler: "QueryCompiler", show_progress: bool = False
-    ) -> None:
-
-        collector = PandasDataFrameCollector(show_progress)
-
-        self._es_results(query_compiler, collector)
+    ) -> pd.DataFrame:
+        df = self._es_results(query_compiler, show_progress)
 
-        return collector._df
+        return df

     def to_csv(
         self,
         query_compiler: "QueryCompiler",
         show_progress: bool = False,
         **kwargs: Union[bool, str],
-    ) -> None:
-
-        collector = PandasToCSVCollector(show_progress, **kwargs)
-
-        self._es_results(query_compiler, collector)
-
-        return collector._ret
+    ) -> Optional[str]:
+        df = self._es_results(query_compiler, show_progress)
+        return df.to_csv(**kwargs)  # type: ignore[no-any-return]

     def _es_results(
-        self,
-        query_compiler: "QueryCompiler",
-        collector: Union["PandasToCSVCollector", "PandasDataFrameCollector"],
-    ) -> None:
+        self, query_compiler: "QueryCompiler", show_progress: bool = False
+    ) -> pd.DataFrame:
         query_params, post_processing = self._resolve_tasks(query_compiler)
 
         result_size, sort_params = Operations._query_params_to_size_and_sort(
@@ -1250,9 +1240,11 @@
             )
         )
 
-        _, df = query_compiler._es_results_to_pandas(es_results)
+        _, df = query_compiler._es_results_to_pandas(
+            results=es_results, show_progress=show_progress
+        )
         df = self._apply_df_post_processing(df, post_processing)
-        collector.collect(df)
+        return df

     def index_count(self, query_compiler: "QueryCompiler", field: str) -> int:
         # field is the index field so count values
@@ -1455,61 +1447,6 @@ def quantile_to_percentile(quantile: Union[int, float]) -> float:
     return float(min(100, max(0, quantile * 100)))


-class PandasToCSVCollector:
-    def __init__(self, show_progress: bool, **kwargs: Union[bool, str]) -> None:
-        self._args = kwargs
-        self._show_progress = show_progress
-        self._ret = None
-        self._first_time = True
-
-    def collect(self, df: "pd.DataFrame") -> None:
-        # If this is the first time we collect results, then write header, otherwise don't write header
-        # and append results
-        if self._first_time:
-            self._first_time = False
-            df.to_csv(**self._args)
-        else:
-            # Don't write header, and change mode to append
-            self._args["header"] = False
-            self._args["mode"] = "a"
-            df.to_csv(**self._args)
-
-    @staticmethod
-    def batch_size() -> int:
-        # By default read n docs and then dump to csv
-        batch_size: int = DEFAULT_CSV_BATCH_OUTPUT_SIZE
-        return batch_size
-
-    @property
-    def show_progress(self) -> bool:
-        return self._show_progress
-
-
-class PandasDataFrameCollector:
-    def __init__(self, show_progress: bool) -> None:
-        self._df = None
-        self._show_progress = show_progress
-
-    def collect(self, df: "pd.DataFrame") -> None:
-        # This collector does not batch data on output. Therefore, batch_size is fixed to None and this method
-        # is only called once.
-        if self._df is not None:
-            raise RuntimeError(
-                "Logic error in execution, this method must only be called once for this"
-                "collector - batch_size == None"
-            )
-        self._df = df
-
-    @staticmethod
-    def batch_size() -> None:
-        # Do not change (see notes on collect)
-        return None
-
-    @property
-    def show_progress(self) -> bool:
-        return self._show_progress
-
-
 def search_yield_hits(
     query_compiler: "QueryCompiler",
     body: Dict[str, Any],
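With the collectors removed, to_pandas() and to_csv() now share one code path: _es_results() drains every page of search hits into a single pandas DataFrame, and serialization is delegated to pandas itself. A rough sketch of that flow, with fetch_pages() as a hypothetical stand-in for search_yield_hits():

from typing import Any, Dict, Iterable, List

import pandas as pd

def fetch_pages() -> Iterable[List[Dict[str, Any]]]:
    # Hypothetical paginated search results.
    yield [{"a": 1}, {"a": 2}]
    yield [{"a": 3}]

hits = [hit for page in fetch_pages() for hit in page]
df = pd.DataFrame(hits)  # one DataFrame for the whole result set
csv_text = df.to_csv()   # pandas returns a str when no path is given

The tradeoff is memory: the deleted PandasToCSVCollector flushed to CSV every DEFAULT_CSV_BATCH_OUTPUT_SIZE rows, while the new path materializes the full result set before writing.
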
2 changes: 1 addition & 1 deletion eland/query_compiler.py
@@ -528,7 +528,7 @@ def to_pandas(self, show_progress: bool = False):
         return self._operations.to_pandas(self, show_progress)
 
     # To CSV
-    def to_csv(self, **kwargs):
+    def to_csv(self, **kwargs) -> Optional[str]:
         """Serialises Eland Dataframe to CSV
         Returns:
22 changes: 22 additions & 0 deletions tests/dataframe/test_to_csv_pytest.py
@@ -19,6 +19,7 @@

 import ast
 import time
+from io import StringIO
 
 import pandas as pd
 from pandas.testing import assert_frame_equal
@@ -99,3 +100,24 @@ def test_to_csv_full(self):

         # clean up index
         ES_TEST_CLIENT.indices.delete(test_index)

+    def test_pd_to_csv_without_filepath(self):
+
+        ed_flights = self.ed_flights()
+        pd_flights = self.pd_flights()
+
+        ret = ed_flights.to_csv()
+        results = StringIO(ret)
+        # Converting back from csv is messy as pd_flights is created from a json file
+        pd_from_csv = pd.read_csv(
+            results,
+            index_col=0,
+            converters={
+                "DestLocation": lambda x: ast.literal_eval(x),
+                "OriginLocation": lambda x: ast.literal_eval(x),
+            },
+        )
+        pd_from_csv.index = pd_from_csv.index.map(str)
+        pd_from_csv.timestamp = pd.to_datetime(pd_from_csv.timestamp)
+
+        assert_frame_equal(pd_flights, pd_from_csv)
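For context on the converters: a CSV round trip stringifies list-valued columns such as DestLocation, so they have to be re-parsed with ast.literal_eval on read. A standalone illustration:

import ast
from io import StringIO

import pandas as pd

# A list-valued column survives to_csv() only as its repr string.
csv_text = pd.DataFrame({"loc": [[52.4, 4.9]]}).to_csv()
df = pd.read_csv(StringIO(csv_text), index_col=0, converters={"loc": ast.literal_eval})
assert df["loc"][0] == [52.4, 4.9]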
