Skip to content

Commit

Permalink
Merge pull request #215 from hgrecco/pre-commit-mypy
Browse files Browse the repository at this point in the history
Enable mypy in pre-commit
  • Loading branch information
andrewgsavage authored Feb 2, 2024
2 parents 5f4c39d + 442336d commit f9d3a66
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 29 deletions.
12 changes: 12 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,15 @@ repos:
hooks:
- id: nbstripout
args: [--extra-keys=metadata.kernelspec metadata.language_info.version]
- repo: https://github.com/pre-commit/mirrors-mypy
rev: "v1.8.0"
hooks:
- id: mypy
verbose: true
args: ["--ignore-missing-imports", "--show-error-codes"]
additional_dependencies: [
"types-requests",
"pandas-stubs",
"pint",
"matplotlib-stubs",
]
4 changes: 2 additions & 2 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# https://www.sphinx-doc.org/en/master/usage/configuration.html

import datetime
from importlib.metadata import version
from importlib.metadata import version as metadata_version

# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
Expand All @@ -22,7 +22,7 @@
# built documents.

try: # pragma: no cover
version = version(project)
version = metadata_version(project)
except Exception: # pragma: no cover
# we seem to have a local copy not installed without setuptools
# so the reported version will be unknown
Expand Down
2 changes: 1 addition & 1 deletion pint_pandas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from importlib.metadata import version
except ImportError:
# Backport for Python < 3.8
from importlib_metadata import version
from importlib_metadata import version # type: ignore

try: # pragma: no cover
__version__ = version("pint_pandas")
Expand Down
49 changes: 28 additions & 21 deletions pint_pandas/pint_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import re
import warnings
from importlib.metadata import version
from typing import Optional
from typing import Any, Callable, Dict, Optional, Union, cast

import numpy as np
import pandas as pd
Expand All @@ -11,15 +11,15 @@
from pandas.api.extensions import (
ExtensionArray,
ExtensionDtype,
ExtensionScalarOpsMixin,
register_dataframe_accessor,
register_extension_dtype,
register_series_accessor,
)
from pandas.api.indexers import check_array_indexer
from pandas.api.types import is_integer, is_list_like, is_object_dtype, is_string_dtype
from pandas.compat import set_function_name
from pandas.core import nanops
from pandas.core.arrays.base import ExtensionOpsMixin
from pandas.core.indexers import check_array_indexer
from pandas.core import nanops # type: ignore
from pint import Quantity as _Quantity
from pint import Unit as _Unit
from pint import compat, errors
Expand Down Expand Up @@ -47,7 +47,7 @@ class PintType(ExtensionDtype):
units: Optional[_Unit] = None # Filled in by `construct_from_..._string`
_metadata = ("units",)
_match = re.compile(r"(P|p)int\[(?P<units>.+)\]")
_cache = {}
_cache = {} # type: ignore
ureg = pint.get_application_registry()

@property
Expand Down Expand Up @@ -78,11 +78,13 @@ def __new__(cls, units=None):
units = cls.ureg.Quantity(1, units).units

try:
return cls._cache["{:P}".format(units)]
# TODO: fix when Pint implements Callable typing
# TODO: wrap string into PintFormatStr class
return cls._cache["{:P}".format(units)] # type: ignore
except KeyError:
u = object.__new__(cls)
u.units = units
cls._cache["{:P}".format(units)] = u
cls._cache["{:P}".format(units)] = u # type: ignore
return u

@classmethod
Expand Down Expand Up @@ -193,9 +195,9 @@ def __repr__(self):


_NumpyEADtype = (
pd.core.dtypes.dtypes.PandasDtype
pd.core.dtypes.dtypes.PandasDtype # type: ignore
if pandas_version_info < (2, 1)
else pd.core.dtypes.dtypes.NumpyEADtype
else pd.core.dtypes.dtypes.NumpyEADtype # type: ignore
)

dtypemap = {
Expand All @@ -215,7 +217,7 @@ def __repr__(self):
dtypeunmap = {v: k for k, v in dtypemap.items()}


class PintArray(ExtensionArray, ExtensionOpsMixin):
class PintArray(ExtensionArray, ExtensionScalarOpsMixin):
"""Implements a class to describe an array of physical quantities:
the product of an array of numerical values and a unit of measurement.
Expand All @@ -234,7 +236,7 @@ class PintArray(ExtensionArray, ExtensionOpsMixin):
"""

_data = np.array([])
_data: ExtensionArray = cast(ExtensionArray, np.array([]))
context_name = None
context_units = None

Expand Down Expand Up @@ -383,7 +385,7 @@ def isna(self):
-------
missing : np.array
"""
return self._data.isna()
return cast(np.ndarray, self._data.isna())

def astype(self, dtype, copy=True):
"""Cast to a NumPy array with 'dtype'.
Expand Down Expand Up @@ -620,11 +622,11 @@ def unique(self):
data = self._data
return self._from_sequence(unique(data), dtype=self.dtype)

def __contains__(self, item) -> bool:
def __contains__(self, item) -> Union[bool, np.bool_]:
if not isinstance(item, _Quantity):
return False
elif pd.isna(item.magnitude):
return self.isna().any()
return cast(np.ndarray, self.isna()).any()
else:
return super().__contains__(item)

Expand Down Expand Up @@ -908,11 +910,12 @@ def _reduce(self, name, *, skipna: bool = True, keepdims: bool = False, **kwds):

if isinstance(self._data, ExtensionArray):
try:
result = self._data._reduce(
# TODO: https://github.com/pandas-dev/pandas-stubs/issues/850
result = self._data._reduce( # type: ignore
name, skipna=skipna, keepdims=keepdims, **kwds
)
except NotImplementedError:
result = functions[name](self.numpy_data, **kwds)
result = cast(_Quantity, functions[name](self.numpy_data, **kwds))

if name in {"all", "any", "kurt", "skew"}:
return result
Expand All @@ -927,15 +930,18 @@ def _reduce(self, name, *, skipna: bool = True, keepdims: bool = False, **kwds):
def _accumulate(self, name: str, *, skipna: bool = True, **kwds):
if name == "cumprod":
raise TypeError("cumprod not supported for pint arrays")
functions = {
functions: Dict[
str, Callable[[np._typing._SupportsArray[np.dtype[Any]]], Any]
] = {
"cummin": np.minimum.accumulate,
"cummax": np.maximum.accumulate,
"cumsum": np.cumsum,
}

if isinstance(self._data, ExtensionArray):
try:
result = self._data._accumulate(name, **kwds)
# TODO: https://github.com/pandas-dev/pandas-stubs/issues/850
result = self._data._accumulate(name, **kwds) # type: ignore
except NotImplementedError:
result = functions[name](self.numpy_data, **kwds)

Expand Down Expand Up @@ -1181,9 +1187,10 @@ def is_pint_type(obj):

try:
# for pint < 0.21 we need to explicitly register
compat.upcast_types.append(PintArray)
# TODO: fix when Pint is properly typed for mypy
compat.upcast_types.append(PintArray) # type: ignore
except AttributeError:
# for pint = 0.21 we need to add the full names of PintArray and DataFrame,
# which is to be added in pint > 0.21
compat.upcast_type_map.setdefault("pint_pandas.pint_array.PintArray", PintArray)
compat.upcast_type_map.setdefault("pandas.core.frame.DataFrame", DataFrame)
compat.upcast_type_map.setdefault("pint_pandas.pint_array.PintArray", PintArray) # type: ignore
compat.upcast_type_map.setdefault("pandas.core.frame.DataFrame", DataFrame) # type: ignore
10 changes: 5 additions & 5 deletions pint_pandas/testsuite/test_pandas_extensiontests.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
use_numpy, # noqa: F401
)


from pint.errors import DimensionalityError

from pint_pandas import PintArray, PintType
Expand Down Expand Up @@ -381,9 +380,10 @@ def _get_expected_exception(
return TypeError
if isinstance(obj, pd.Series):
try:
if obj.pint.m.dtype.kind == "c":
# PintSeriesAccessor is dynamically constructed; need stubs to make it mypy-compatible
if obj.pint.m.dtype.kind == "c": # type: ignore
pytest.skip(
f"{obj.pint.m.dtype.name} {obj.dtype} does not support {op_name}"
f"{obj.pint.m.dtype.name} {obj.dtype} does not support {op_name}" # type: ignore
)
return TypeError
except AttributeError:
Expand All @@ -392,9 +392,9 @@ def _get_expected_exception(
return exc
if isinstance(other, pd.Series):
try:
if other.pint.m.dtype.kind == "c":
if other.pint.m.dtype.kind == "c": # type: ignore
pytest.skip(
f"{other.pint.m.dtype.name} {other.dtype} does not support {op_name}"
f"{other.pint.m.dtype.name} {other.dtype} does not support {op_name}" # type: ignore
)
return TypeError
except AttributeError:
Expand Down

0 comments on commit f9d3a66

Please sign in to comment.