From bd7956ea7210a87e8085a1ffd8107134f4371dff Mon Sep 17 00:00:00 2001 From: Seth Michael Larson Date: Tue, 20 Oct 2020 16:26:58 -0500 Subject: [PATCH] Support typed 'elasticsearch-py' and add 'py.typed' --- MANIFEST.in | 1 + eland/common.py | 15 +++++++++----- eland/etl.py | 6 +++--- eland/ml/ml_model.py | 10 +++++---- eland/ndframe.py | 45 ++++++++++++++++++++++++----------------- eland/py.typed | 0 eland/query_compiler.py | 8 ++++---- noxfile.py | 1 + setup.py | 3 +++ 9 files changed, 55 insertions(+), 34 deletions(-) create mode 100644 eland/py.typed diff --git a/MANIFEST.in b/MANIFEST.in index 6006776d..2ad7baa0 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,2 +1,3 @@ include LICENSE.txt include README.md +recursive-include eland py.typed diff --git a/eland/common.py b/eland/common.py index 647d598e..a5a2afd4 100644 --- a/eland/common.py +++ b/eland/common.py @@ -22,7 +22,7 @@ import numpy as np # type: ignore import pandas as pd # type: ignore -from elasticsearch import Elasticsearch # type: ignore +from elasticsearch import Elasticsearch # Default number of rows displayed (different to pandas where ALL could be displayed) DEFAULT_NUM_ROWS_DISPLAYED = 60 @@ -86,7 +86,7 @@ def from_string(order: str) -> "SortOrder": def elasticsearch_date_to_pandas_date( - value: Union[int, str], date_format: str + value: Union[int, str], date_format: Optional[str] ) -> pd.Timestamp: """ Given a specific Elasticsearch format for a date datatype, returns the @@ -298,6 +298,7 @@ def es_version(es_client: Elasticsearch) -> Tuple[int, int, int]: """Tags the current ES client with a cached '_eland_es_version' property if one doesn't exist yet for the current Elasticsearch version. """ + eland_es_version: Tuple[int, int, int] if not hasattr(es_client, "_eland_es_version"): version_info = es_client.info()["version"]["number"] match = re.match(r"^(\d+)\.(\d+)\.(\d+)", version_info) @@ -306,6 +307,10 @@ def es_version(es_client: Elasticsearch) -> Tuple[int, int, int]: f"Unable to determine Elasticsearch version. " f"Received: {version_info}" ) - major, minor, patch = [int(x) for x in match.groups()] - es_client._eland_es_version = (major, minor, patch) - return cast(Tuple[int, int, int], es_client._eland_es_version) + eland_es_version = cast( + Tuple[int, int, int], tuple([int(x) for x in match.groups()]) + ) + es_client._eland_es_version = eland_es_version # type: ignore + else: + eland_es_version = es_client._eland_es_version # type: ignore + return eland_es_version diff --git a/eland/etl.py b/eland/etl.py index 00e6c0f7..8f5cf87d 100644 --- a/eland/etl.py +++ b/eland/etl.py @@ -20,8 +20,8 @@ from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, Union import pandas as pd # type: ignore -from elasticsearch import Elasticsearch # type: ignore -from elasticsearch.helpers import parallel_bulk # type: ignore +from elasticsearch import Elasticsearch +from elasticsearch.helpers import parallel_bulk from pandas.io.parsers import _c_parser_defaults # type: ignore from eland import DataFrame @@ -240,7 +240,7 @@ def action_generator( pd_df, es_dropna, use_pandas_index_for_es_ids, es_dest_index ), thread_count=thread_count, - chunk_size=chunksize / thread_count, + chunk_size=int(chunksize / thread_count), ), maxlen=0, ) diff --git a/eland/ml/ml_model.py b/eland/ml/ml_model.py index 3ca08948..41c22a0d 100644 --- a/eland/ml/ml_model.py +++ b/eland/ml/ml_model.py @@ -18,7 +18,7 @@ import warnings from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast -import elasticsearch # type: ignore +import elasticsearch import numpy as np # type: ignore from eland.common import ensure_es_client, es_version @@ -447,11 +447,13 @@ def _trained_model_config(self) -> Dict[str, Any]: # In Elasticsearch 7.7 and earlier you can't get # target type without pulling the model definition # so we check the version first. - kwargs = {} if es_version(self._client) < (7, 8): - kwargs["include_model_definition"] = True + resp = self._client.ml.get_trained_models( + model_id=self._model_id, include_model_definition=True + ) + else: + resp = self._client.ml.get_trained_models(model_id=self._model_id) - resp = self._client.ml.get_trained_models(model_id=self._model_id, **kwargs) if resp["count"] > 1: raise ValueError(f"Model ID {self._model_id!r} wasn't unambiguous") elif resp["count"] == 0: diff --git a/eland/ndframe.py b/eland/ndframe.py index 17ef2886..0d603746 100644 --- a/eland/ndframe.py +++ b/eland/ndframe.py @@ -17,13 +17,15 @@ import sys from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Optional, Tuple +from typing import TYPE_CHECKING, List, Optional, Tuple, Union -import pandas as pd +import pandas as pd # type: ignore from eland.query_compiler import QueryCompiler if TYPE_CHECKING: + from elasticsearch import Elasticsearch + from eland.index import Index """ @@ -55,12 +57,14 @@ class NDFrame(ABC): def __init__( self, - es_client=None, - es_index_pattern=None, - columns=None, - es_index_field=None, - _query_compiler=None, - ): + es_client: Optional[ + Union[str, List[str], Tuple[str, ...], "Elasticsearch"] + ] = None, + es_index_pattern: Optional[str] = None, + columns: Optional[List[str]] = None, + es_index_field: Optional[str] = None, + _query_compiler: Optional[QueryCompiler] = None, + ) -> None: """ pandas.DataFrame/Series like API that proxies into Elasticsearch index(es). @@ -134,7 +138,7 @@ def dtypes(self) -> pd.Series: return self._query_compiler.dtypes @property - def es_dtypes(self): + def es_dtypes(self) -> pd.Series: """ Return the Elasticsearch dtypes in the index @@ -155,7 +159,7 @@ def es_dtypes(self): """ return self._query_compiler.es_dtypes - def _build_repr(self, num_rows) -> pd.DataFrame: + def _build_repr(self, num_rows: int) -> pd.DataFrame: # self could be Series or DataFrame if len(self.index) <= num_rows: return self.to_pandas() @@ -639,20 +643,25 @@ def describe(self) -> pd.DataFrame: return self._query_compiler.describe() @abstractmethod - def to_pandas(self, show_progress=False): - pass + def to_pandas(self, show_progress: bool = False) -> pd.DataFrame: + raise NotImplementedError @abstractmethod - def head(self, n=5): - pass + def head(self, n: int = 5) -> "NDFrame": + raise NotImplementedError @abstractmethod - def tail(self, n=5): - pass + def tail(self, n: int = 5) -> "NDFrame": + raise NotImplementedError @abstractmethod - def sample(self, n=None, frac=None, random_state=None): - pass + def sample( + self, + n: Optional[int] = None, + frac: Optional[float] = None, + random_state: Optional[int] = None, + ) -> "NDFrame": + raise NotImplementedError @property def shape(self) -> Tuple[int, ...]: diff --git a/eland/py.typed b/eland/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/eland/query_compiler.py b/eland/query_compiler.py index 956e402d..c1395fdd 100644 --- a/eland/query_compiler.py +++ b/eland/query_compiler.py @@ -94,11 +94,11 @@ def __init__( self._operations = Operations() @property - def index(self): + def index(self) -> Index: return self._index @property - def columns(self): + def columns(self) -> pd.Index: columns = self._mappings.display_names return pd.Index(columns) @@ -120,11 +120,11 @@ def add_scripted_field(self, scripted_field_name, display_name, pd_dtype): return result @property - def dtypes(self): + def dtypes(self) -> pd.Series: return self._mappings.dtypes() @property - def es_dtypes(self): + def es_dtypes(self) -> pd.Series: return self._mappings.es_dtypes() # END Index, columns, and dtypes objects diff --git a/noxfile.py b/noxfile.py index 52517f53..3f3a170f 100644 --- a/noxfile.py +++ b/noxfile.py @@ -68,6 +68,7 @@ def format(session): @nox.session(reuse_venv=True) def lint(session): session.install("black", "flake8", "mypy", "isort") + session.install("--pre", "elasticsearch") session.run("python", "utils/license-headers.py", "check", *SOURCE_FILES) session.run("black", "--check", "--target-version=py36", *SOURCE_FILES) session.run("isort", "--check", *SOURCE_FILES) diff --git a/setup.py b/setup.py index 8b36258a..4ebc4dbe 100644 --- a/setup.py +++ b/setup.py @@ -72,6 +72,9 @@ packages=find_packages(include=["eland", "eland.*"]), install_requires=["elasticsearch>=7.7", "pandas>=1", "matplotlib", "numpy"], python_requires=">=3.6", + package_data={"eland": ["py.typed"]}, + include_package_data=True, + zip_safe=False, extras_require={ "xgboost": ["xgboost>=0.90,<2"], "scikit-learn": ["scikit-learn>=0.22.1,<1"],