Skip to content

Commit

Permalink
Refactor upcast type check to check actual types rather than names
Browse files Browse the repository at this point in the history
  • Loading branch information
jthielen committed Dec 27, 2019
1 parent 434852c commit 5ce5c46
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 14 deletions.
2 changes: 1 addition & 1 deletion CHANGES
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ Pint Changelog
- Improved compatbility for upcast types like xarray's DataArray or Dataset, to which
Pint Quantities now fully defer for arithmetic and NumPy operations. A collection of
basic tests for proper deferral has been added (for full integration tests, see
xarray's test suite). The list of names of upcast types is available at
xarray's test suite). The list of upcast types is available at
`pint.compat.upcast_types` in the API.
(Issue #959, Thanks Jon Thielen)
- Moved docstrings to Numpy Docs
Expand Down
6 changes: 3 additions & 3 deletions docs/numpy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,9 @@ upcast types to which Pint defers (see
- ``Series``, as defined by pandas
- ``DataArray``, ``Dataset``, and ``Variable``, as defined by xarray

If your application requires extension of this collection of types, the collection of
type names is available in Pint's API at ``pint.compat.upcast_types``. Note that these
are also the types to which a Quantity object will defer for arithmetic operations.
If your application requires extension of this collection of types, it is available in
Pint's API at ``pint.compat.upcast_types``. Note that these are also the types to which
a Quantity object will defer for arithmetic operations.

To achive these function and ufunc overrides, Pint uses the ``__array_function__`` and
``__array_ufunc__`` protocols respectively, as recommened by NumPy. This means that
Expand Down
33 changes: 28 additions & 5 deletions pint/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,12 +149,36 @@ def _to_magnitude(value, force_ndarray=False):
babel_parse = babel_units = missing_dependency("Babel") # noqa: F811

# Define location of pint.Quantity in NEP-13 type cast hierarchy by defining upcast and
# downcast/wrappable types
upcast_types = ["PintArray", "Series", "DataArray", "Dataset", "Variable"]
# downcast/wrappable types using guarded imports
upcast_types = []

# pint-pandas (PintArray)
try:
from pintpandas import PintArray

upcast_types.append(PintArray)
except ImportError:
pass

# Pandas (Series)
try:
from pandas import Series

upcast_types.append(Series)
except ImportError:
pass

# xarray (DataArray, Dataset, Variable)
try:
from xarray import DataArray, Dataset, Variable

upcast_types += [DataArray, Dataset, Variable]
except ImportError:
pass


def is_upcast_type(other):
"""Check if the type object is a upcast type.
"""Check if the type object is a upcast type using preset list.
Parameters
----------
Expand All @@ -164,8 +188,7 @@ def is_upcast_type(other):
-------
bool
"""
# Check if class name is in preset list
return other.__name__ in upcast_types
return other in upcast_types


def eq(lhs, rhs, check_all):
Expand Down
12 changes: 7 additions & 5 deletions pint/testsuite/test_quantity.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@
from pint.unit import UnitsContainer


class FakeWrapper:
# Used in test_upcast_type_rejection_on_creation
def __init__(self, q):
self.q = q


class TestQuantity(QuantityTestCase):

FORCE_NDARRAY = False
Expand Down Expand Up @@ -536,12 +542,8 @@ def test_array_function_warning_on_creation(self):
def test_no_ndarray_coercion_without_numpy(self):
self.assertRaises(ValueError, self.Q_(1, "m").__array__)

@patch("pint.compat.upcast_types", ["FakeWrapper"])
@patch("pint.compat.upcast_types", [FakeWrapper])
def test_upcast_type_rejection_on_creation(self):
class FakeWrapper:
def __init__(self, q):
self.q = q

self.assertRaises(TypeError, self.Q_, FakeWrapper(42), "m")
self.assertEqual(FakeWrapper(self.Q_(42, "m")).q, self.Q_(42, "m"))

Expand Down

0 comments on commit 5ce5c46

Please sign in to comment.