Skip to content

Commit

Permalink
Add more type hints to APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
sethmlarson committed Jul 14, 2020
1 parent 6e6ad04 commit 8434a1f
Show file tree
Hide file tree
Showing 8 changed files with 155 additions and 97 deletions.
2 changes: 1 addition & 1 deletion eland/arithmetics.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def resolve(self) -> str:
for task in self._tasks:
if task.op_name == "__add__":
value = f"({value} + {task.object.resolve()})"
elif task.op_name == "__truediv__":
elif task.op_name in ("__truediv__", "__div__"):
value = f"({value} / {task.object.resolve()})"
elif task.op_name == "__floordiv__":
value = f"Math.floor({value} / {task.object.resolve()})"
Expand Down
38 changes: 27 additions & 11 deletions eland/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import warnings
from io import StringIO
import re
from typing import Optional, Sequence, Union
from typing import Optional, Sequence, Union, Tuple

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -631,7 +631,7 @@ def es_info(self):
def info_es(self):
return self.es_info()

def es_query(self, query):
def es_query(self, query) -> "DataFrame":
"""Applies an Elasticsearch DSL query to the current DataFrame.
Parameters
Expand Down Expand Up @@ -705,7 +705,7 @@ def _index_summary(self):

def info(
self, verbose=None, buf=None, max_cols=None, memory_usage=None, null_counts=None
):
) -> None:
"""
Print a concise summary of a DataFrame.
Expand Down Expand Up @@ -822,7 +822,7 @@ def _verbose_repr():
dtype = dtypes.iloc[i]
col = pprint_thing(col)

line_no = _put_str(" {num}".format(num=i), space_num)
line_no = _put_str(f" {i}", space_num)

count = ""
if show_counts:
Expand Down Expand Up @@ -1223,7 +1223,7 @@ def to_csv(
}
return self._query_compiler.to_csv(**kwargs)

def to_pandas(self, show_progress: bool = False) -> "DataFrame":
def to_pandas(self, show_progress: bool = False) -> pd.DataFrame:
"""
Utility method to convert eland.DataFrame to pandas.DataFrame
Expand All @@ -1233,10 +1233,10 @@ def to_pandas(self, show_progress: bool = False) -> "DataFrame":
"""
return self._query_compiler.to_pandas(show_progress=show_progress)

def _empty_pd_df(self):
def _empty_pd_df(self) -> pd.DataFrame:
return self._query_compiler._empty_pd_ef()

def select_dtypes(self, include=None, exclude=None):
def select_dtypes(self, include=None, exclude=None) -> "DataFrame":
"""
Return a subset of the DataFrame's columns based on the column dtypes.
Expand Down Expand Up @@ -1272,7 +1272,7 @@ def select_dtypes(self, include=None, exclude=None):
return self._getitem_array(empty_df.columns)

@property
def shape(self):
def shape(self) -> Tuple[int, int]:
"""
Return a tuple representing the dimensionality of the DataFrame.
Expand All @@ -1299,7 +1299,23 @@ def shape(self):

return num_rows, num_columns

def keys(self):
@property
def ndim(self) -> int:
"""
Returns 2 by definition of a DataFrame
Returns
-------
int
By definition 2
See Also
--------
:pandas_api_docs:`pandas.DataFrame.ndim`
"""
return 2

def keys(self) -> pd.Index:
"""
Return columns
Expand Down Expand Up @@ -1381,7 +1397,7 @@ def aggregate(self, func, axis=0, *args, **kwargs):

hist = gfx.ed_hist_frame

def query(self, expr):
def query(self, expr) -> "DataFrame":
"""
Query the columns of a DataFrame with a boolean expression.
Expand Down Expand Up @@ -1474,7 +1490,7 @@ def filter(
like: Optional[str] = None,
regex: Optional[str] = None,
axis: Optional[Union[int, str]] = None,
):
) -> "DataFrame":
"""
Subset the dataframe rows or columns according to the specified index labels.
Note that this routine does not filter a dataframe on its
Expand Down
32 changes: 20 additions & 12 deletions eland/field_mappings.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,16 @@
is_string_dtype,
)
from pandas.core.dtypes.inference import is_list_like
from typing import NamedTuple, Optional, Mapping, Dict, Any, TYPE_CHECKING
from typing import NamedTuple, Optional, Mapping, Dict, Any, TYPE_CHECKING, List, Set

if TYPE_CHECKING:
from elasticsearch import Elasticsearch
from eland import DataFrame


ES_FLOAT_TYPES = {"double", "float", "half_float", "scaled_float"}
ES_INTEGER_TYPES = {"long", "integer", "short", "byte"}
ES_COMPATIBLE_TYPES = {
ES_FLOAT_TYPES: Set[str] = {"double", "float", "half_float", "scaled_float"}
ES_INTEGER_TYPES: Set[str] = {"long", "integer", "short", "byte"}
ES_COMPATIBLE_TYPES: Dict[str, Set[str]] = {
"double": ES_FLOAT_TYPES,
"scaled_float": ES_FLOAT_TYPES,
"float": ES_FLOAT_TYPES,
Expand Down Expand Up @@ -80,7 +81,7 @@ def is_bool(self) -> bool:
def np_dtype(self):
return np.dtype(self.pd_dtype)

def is_es_agg_compatible(self, es_agg):
def is_es_agg_compatible(self, es_agg) -> bool:
# Cardinality works for all types
# Numerics and bools work for all aggs
if es_agg == "cardinality" or self.is_numeric or self.is_bool:
Expand Down Expand Up @@ -115,7 +116,7 @@ class FieldMappings:
or es_field_name.keyword (if exists) or None
"""

ES_DTYPE_TO_PD_DTYPE = {
ES_DTYPE_TO_PD_DTYPE: Dict[str, str] = {
"text": "object",
"keyword": "object",
"long": "int64",
Expand All @@ -133,7 +134,7 @@ class FieldMappings:
}

# the labels for each column (display_name is index)
column_labels = [
column_labels: List[str] = [
"es_field_name",
"is_source",
"es_dtype",
Expand All @@ -145,7 +146,12 @@ class FieldMappings:
"aggregatable_es_field_name",
]

def __init__(self, client=None, index_pattern=None, display_names=None):
def __init__(
self,
client: "Elasticsearch",
index_pattern: str,
display_names: Optional[List[str]] = None,
):
"""
Parameters
----------
Expand Down Expand Up @@ -184,7 +190,9 @@ def __init__(self, client=None, index_pattern=None, display_names=None):
self.display_names = display_names

@staticmethod
def _extract_fields_from_mapping(mappings, source_only=False, date_format=None):
def _extract_fields_from_mapping(
mappings: Dict[str, Any], source_only: bool = False
) -> Dict[str, str]:
"""
Extract all field names and types from a mapping.
```
Expand Down Expand Up @@ -256,10 +264,10 @@ def _extract_fields_from_mapping(mappings, source_only=False, date_format=None):

# Recurse until we get a 'type: xxx'
def flatten(x, name=""):
if type(x) is dict:
if isinstance(x, dict):
for a in x:
if (
a == "type" and type(x[a]) is str
if a == "type" and isinstance(
x[a], str
): # 'type' can be a name of a field
field_name = name[:-1]
field_type = x[a]
Expand Down
41 changes: 22 additions & 19 deletions eland/ndframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,14 @@

import sys
from abc import ABC, abstractmethod
from typing import Tuple

from typing import TYPE_CHECKING, Tuple
import pandas as pd
from eland.query_compiler import QueryCompiler


if TYPE_CHECKING:
from eland.index import Index

"""
NDFrame
---------
Expand Down Expand Up @@ -73,7 +77,8 @@ def __init__(
)
self._query_compiler = _query_compiler

def _get_index(self):
@property
def index(self) -> "Index":
"""
Return eland index referencing Elasticsearch field to index a DataFrame/Series
Expand All @@ -100,10 +105,8 @@ def _get_index(self):
"""
return self._query_compiler.index

index = property(_get_index)

@property
def dtypes(self):
def dtypes(self) -> pd.Series:
"""
Return the pandas dtypes in the DataFrame. Elasticsearch types are mapped
to pandas dtypes via Mappings._es_dtype_to_pd_dtype.__doc__
Expand All @@ -129,7 +132,7 @@ def dtypes(self):
"""
return self._query_compiler.dtypes

def _build_repr(self, num_rows):
def _build_repr(self, num_rows) -> pd.DataFrame:
# self could be Series or DataFrame
if len(self.index) <= num_rows:
return self.to_pandas()
Expand All @@ -144,11 +147,11 @@ def _build_repr(self, num_rows):

return head.append(tail)

def __sizeof__(self):
def __sizeof__(self) -> int:
# Don't default to pandas, just return approximation TODO - make this more accurate
return sys.getsizeof(self._query_compiler)

def __len__(self):
def __len__(self) -> int:
"""Gets the length of the DataFrame.
Returns:
Expand All @@ -159,7 +162,7 @@ def __len__(self):
def _es_info(self, buf):
self._query_compiler.es_info(buf)

def mean(self, numeric_only=True):
def mean(self, numeric_only: bool = True) -> pd.Series:
"""
Return mean value for each numeric column
Expand Down Expand Up @@ -191,7 +194,7 @@ def mean(self, numeric_only=True):
"""
return self._query_compiler.mean(numeric_only=numeric_only)

def sum(self, numeric_only=True):
def sum(self, numeric_only: bool = True) -> pd.Series:
"""
Return sum for each numeric column
Expand Down Expand Up @@ -223,7 +226,7 @@ def sum(self, numeric_only=True):
"""
return self._query_compiler.sum(numeric_only=numeric_only)

def min(self, numeric_only=True):
def min(self, numeric_only: bool = True) -> pd.Series:
"""
Return the minimum value for each numeric column
Expand Down Expand Up @@ -255,7 +258,7 @@ def min(self, numeric_only=True):
"""
return self._query_compiler.min(numeric_only=numeric_only)

def var(self, numeric_only=True):
def var(self, numeric_only: bool = True) -> pd.Series:
"""
Return variance for each numeric column
Expand Down Expand Up @@ -285,7 +288,7 @@ def var(self, numeric_only=True):
"""
return self._query_compiler.var(numeric_only=numeric_only)

def std(self, numeric_only=True):
def std(self, numeric_only: bool = True) -> pd.Series:
"""
Return standard deviation for each numeric column
Expand Down Expand Up @@ -315,7 +318,7 @@ def std(self, numeric_only=True):
"""
return self._query_compiler.std(numeric_only=numeric_only)

def median(self, numeric_only=True):
def median(self, numeric_only: bool = True) -> pd.Series:
"""
Return the median value for each numeric column
Expand Down Expand Up @@ -345,7 +348,7 @@ def median(self, numeric_only=True):
"""
return self._query_compiler.median(numeric_only=numeric_only)

def max(self, numeric_only=True):
def max(self, numeric_only: bool = True) -> pd.Series:
"""
Return the maximum value for each numeric column
Expand Down Expand Up @@ -377,7 +380,7 @@ def max(self, numeric_only=True):
"""
return self._query_compiler.max(numeric_only=numeric_only)

def nunique(self):
def nunique(self) -> pd.Series:
"""
Return cardinality of each field.
Expand Down Expand Up @@ -423,7 +426,7 @@ def nunique(self):
"""
return self._query_compiler.nunique()

def mad(self, numeric_only=True):
def mad(self, numeric_only: bool = True) -> pd.Series:
"""
Return median absolute deviation for each numeric column
Expand Down Expand Up @@ -456,7 +459,7 @@ def mad(self, numeric_only=True):
def _hist(self, num_bins):
return self._query_compiler._hist(num_bins)

def describe(self):
def describe(self) -> pd.DataFrame:
"""
Generate descriptive statistics that summarize the central tendency, dispersion and shape of a
dataset’s distribution, excluding NaN values.
Expand Down
Loading

0 comments on commit 8434a1f

Please sign in to comment.