Skip to content

Commit

Permalink
Support typed 'elasticsearch-py' and add 'py.typed'
Browse files Browse the repository at this point in the history
  • Loading branch information
sethmlarson authored Oct 20, 2020
1 parent 05a24cb commit bd7956e
Show file tree
Hide file tree
Showing 9 changed files with 55 additions and 34 deletions.
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
include LICENSE.txt
include README.md
recursive-include eland py.typed
15 changes: 10 additions & 5 deletions eland/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

import numpy as np # type: ignore
import pandas as pd # type: ignore
from elasticsearch import Elasticsearch # type: ignore
from elasticsearch import Elasticsearch

# Default number of rows displayed (different to pandas where ALL could be displayed)
DEFAULT_NUM_ROWS_DISPLAYED = 60
Expand Down Expand Up @@ -86,7 +86,7 @@ def from_string(order: str) -> "SortOrder":


def elasticsearch_date_to_pandas_date(
value: Union[int, str], date_format: str
value: Union[int, str], date_format: Optional[str]
) -> pd.Timestamp:
"""
Given a specific Elasticsearch format for a date datatype, returns the
Expand Down Expand Up @@ -298,6 +298,7 @@ def es_version(es_client: Elasticsearch) -> Tuple[int, int, int]:
"""Tags the current ES client with a cached '_eland_es_version'
property if one doesn't exist yet for the current Elasticsearch version.
"""
eland_es_version: Tuple[int, int, int]
if not hasattr(es_client, "_eland_es_version"):
version_info = es_client.info()["version"]["number"]
match = re.match(r"^(\d+)\.(\d+)\.(\d+)", version_info)
Expand All @@ -306,6 +307,10 @@ def es_version(es_client: Elasticsearch) -> Tuple[int, int, int]:
f"Unable to determine Elasticsearch version. "
f"Received: {version_info}"
)
major, minor, patch = [int(x) for x in match.groups()]
es_client._eland_es_version = (major, minor, patch)
return cast(Tuple[int, int, int], es_client._eland_es_version)
eland_es_version = cast(
Tuple[int, int, int], tuple([int(x) for x in match.groups()])
)
es_client._eland_es_version = eland_es_version # type: ignore
else:
eland_es_version = es_client._eland_es_version # type: ignore
return eland_es_version
6 changes: 3 additions & 3 deletions eland/etl.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, Union

import pandas as pd # type: ignore
from elasticsearch import Elasticsearch # type: ignore
from elasticsearch.helpers import parallel_bulk # type: ignore
from elasticsearch import Elasticsearch
from elasticsearch.helpers import parallel_bulk
from pandas.io.parsers import _c_parser_defaults # type: ignore

from eland import DataFrame
Expand Down Expand Up @@ -240,7 +240,7 @@ def action_generator(
pd_df, es_dropna, use_pandas_index_for_es_ids, es_dest_index
),
thread_count=thread_count,
chunk_size=chunksize / thread_count,
chunk_size=int(chunksize / thread_count),
),
maxlen=0,
)
Expand Down
10 changes: 6 additions & 4 deletions eland/ml/ml_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import warnings
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast

import elasticsearch # type: ignore
import elasticsearch
import numpy as np # type: ignore

from eland.common import ensure_es_client, es_version
Expand Down Expand Up @@ -447,11 +447,13 @@ def _trained_model_config(self) -> Dict[str, Any]:
# In Elasticsearch 7.7 and earlier you can't get
# target type without pulling the model definition
# so we check the version first.
kwargs = {}
if es_version(self._client) < (7, 8):
kwargs["include_model_definition"] = True
resp = self._client.ml.get_trained_models(
model_id=self._model_id, include_model_definition=True
)
else:
resp = self._client.ml.get_trained_models(model_id=self._model_id)

resp = self._client.ml.get_trained_models(model_id=self._model_id, **kwargs)
if resp["count"] > 1:
raise ValueError(f"Model ID {self._model_id!r} wasn't unambiguous")
elif resp["count"] == 0:
Expand Down
45 changes: 27 additions & 18 deletions eland/ndframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@

import sys
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Optional, Tuple
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

import pandas as pd
import pandas as pd # type: ignore

from eland.query_compiler import QueryCompiler

if TYPE_CHECKING:
from elasticsearch import Elasticsearch

from eland.index import Index

"""
Expand Down Expand Up @@ -55,12 +57,14 @@
class NDFrame(ABC):
def __init__(
self,
es_client=None,
es_index_pattern=None,
columns=None,
es_index_field=None,
_query_compiler=None,
):
es_client: Optional[
Union[str, List[str], Tuple[str, ...], "Elasticsearch"]
] = None,
es_index_pattern: Optional[str] = None,
columns: Optional[List[str]] = None,
es_index_field: Optional[str] = None,
_query_compiler: Optional[QueryCompiler] = None,
) -> None:
"""
pandas.DataFrame/Series like API that proxies into Elasticsearch index(es).
Expand Down Expand Up @@ -134,7 +138,7 @@ def dtypes(self) -> pd.Series:
return self._query_compiler.dtypes

@property
def es_dtypes(self):
def es_dtypes(self) -> pd.Series:
"""
Return the Elasticsearch dtypes in the index
Expand All @@ -155,7 +159,7 @@ def es_dtypes(self):
"""
return self._query_compiler.es_dtypes

def _build_repr(self, num_rows) -> pd.DataFrame:
def _build_repr(self, num_rows: int) -> pd.DataFrame:
# self could be Series or DataFrame
if len(self.index) <= num_rows:
return self.to_pandas()
Expand Down Expand Up @@ -639,20 +643,25 @@ def describe(self) -> pd.DataFrame:
return self._query_compiler.describe()

@abstractmethod
def to_pandas(self, show_progress=False):
pass
def to_pandas(self, show_progress: bool = False) -> pd.DataFrame:
raise NotImplementedError

@abstractmethod
def head(self, n=5):
pass
def head(self, n: int = 5) -> "NDFrame":
raise NotImplementedError

@abstractmethod
def tail(self, n=5):
pass
def tail(self, n: int = 5) -> "NDFrame":
raise NotImplementedError

@abstractmethod
def sample(self, n=None, frac=None, random_state=None):
pass
def sample(
self,
n: Optional[int] = None,
frac: Optional[float] = None,
random_state: Optional[int] = None,
) -> "NDFrame":
raise NotImplementedError

@property
def shape(self) -> Tuple[int, ...]:
Expand Down
Empty file added eland/py.typed
Empty file.
8 changes: 4 additions & 4 deletions eland/query_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,11 @@ def __init__(
self._operations = Operations()

@property
def index(self):
def index(self) -> Index:
return self._index

@property
def columns(self):
def columns(self) -> pd.Index:
columns = self._mappings.display_names

return pd.Index(columns)
Expand All @@ -120,11 +120,11 @@ def add_scripted_field(self, scripted_field_name, display_name, pd_dtype):
return result

@property
def dtypes(self):
def dtypes(self) -> pd.Series:
return self._mappings.dtypes()

@property
def es_dtypes(self):
def es_dtypes(self) -> pd.Series:
return self._mappings.es_dtypes()

# END Index, columns, and dtypes objects
Expand Down
1 change: 1 addition & 0 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def format(session):
@nox.session(reuse_venv=True)
def lint(session):
session.install("black", "flake8", "mypy", "isort")
session.install("--pre", "elasticsearch")
session.run("python", "utils/license-headers.py", "check", *SOURCE_FILES)
session.run("black", "--check", "--target-version=py36", *SOURCE_FILES)
session.run("isort", "--check", *SOURCE_FILES)
Expand Down
3 changes: 3 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@
packages=find_packages(include=["eland", "eland.*"]),
install_requires=["elasticsearch>=7.7", "pandas>=1", "matplotlib", "numpy"],
python_requires=">=3.6",
package_data={"eland": ["py.typed"]},
include_package_data=True,
zip_safe=False,
extras_require={
"xgboost": ["xgboost>=0.90,<2"],
"scikit-learn": ["scikit-learn>=0.22.1,<1"],
Expand Down

0 comments on commit bd7956e

Please sign in to comment.